"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import json
import logging
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
from transformers.modeling_roberta import RobertaForMaskedLMDomainTask
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
def load_GeneralDomain(dir_data_out):
    """Load the pre-computed [CLS] representations and raw sentences of the retrieval corpus."""
    print("Load CLS.pt and train.json")
    # Alternative corpus (open domain):
    #docs = torch.load(dir_data_out+"opendomain_CLS.pt")
    #with open(dir_data_out+"opendomain.json") as file:
    #    data = json.load(file)
    docs = torch.load(dir_data_out+"train_CLS.pt")
    with open(dir_data_out+"train.json") as file:
        data = json.load(file)
    print("train.json Done")
    return docs, data

docs, data = load_GeneralDomain("data/yelp/")
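# Illustrative check of what this loader is expected to produce. Hedged: the
# layout is inferred from how `docs` and `data` are used in the training loop
# below, not from the data files themselves. `docs` is matmul'd against a
# (hidden, batch) query and summed over its middle dimension, so it is assumed
# to be (num_docs, num_layers, hidden); `data` is indexed as
# data[str(i)]['sentence'], so it is assumed to map string ids to records.
#
#   docs, data = load_GeneralDomain("data/yelp/")
#   print(docs.shape)              # e.g. torch.Size([50000, 13, 768])
#   print(data["0"]["sentence"])   # raw text of retrieval-corpus sentence 0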
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
    """Build one feature tuple per in-domain sentence (see the sketch after this function)."""
    with open(data_dir_indomain+"train.json") as file:
        data = json.load(file)

    # Collect the aspect label set and map each label to an integer id.
    num_label_list = list()
    label_sentence_dict = dict()
    for line in data:
        num_label_list.append(line["aspect"])

    num_label = sorted(list(set(num_label_list)))
    label_map = {label: i for i, label in enumerate(num_label)}

    all_input_ids = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()

    cur_tensors_list = list()
    for id, line in enumerate(data):
        sentence = line["sentence"]
        label = line["aspect"]

        tokens_a = tokenizer.tokenize(sentence)

        # Replace any literal "</s>" inside the sentence so only the
        # separator appended later is a real special token.
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"

        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs),
                       torch.tensor(label_map[label]))

        cur_tensors_list.append(cur_tensors)

        all_input_ids.append(torch.tensor(cur_features.input_ids))
        all_input_mask.append(torch.tensor(cur_features.input_mask))
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
        all_is_next.append(torch.tensor(0))
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
        all_sentence_labels.append(torch.tensor(label_map[label]))

    # Stacked version of the same features, kept for parity with the binary
    # loader below; the per-example list is what gets returned.
    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_sentence_labels))

    return cur_tensors_list
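# Minimal usage sketch for the loader above (hedged: assumes a RoBERTa
# checkpoint and a train.json whose records have "sentence" and "aspect"
# fields, as the code reads them):
#
#   tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
#   train_dataset = in_Domain_Task_Data_mutiple("data/yelp/", tokenizer, 128)
#   input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, \
#       is_next, tail_idxs, sentence_label = train_dataset[0]
#   print(input_ids.shape)   # torch.Size([128]); one 8-tuple per sentence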
def in_Domain_Task_Data_binary(data_dir_indomain, tokenizer, max_seq_length):
    """Group every sentence with one random sentence from each other aspect (see the sketch below)."""
    with open(data_dir_indomain+"train.json") as file:
        data = json.load(file)

    # Collect the label set and bucket the sentences by aspect.
    num_label_list = list()
    label_sentence_dict = dict()
    for line in data:
        num_label_list.append(line["aspect"])
        if line["aspect"] in label_sentence_dict:
            label_sentence_dict[line["aspect"]].append([line["sentence"]])
        else:
            # store one-element lists so random.choice(...)[0] below unwraps
            # to the sentence rather than its first character
            label_sentence_dict[line["aspect"]] = [[line["sentence"]]]

    num_label = sorted(list(set(num_label_list)))
    label_map = {label: i for i, label in enumerate(num_label)}

    all_cur_tensors = list()
    for line in data:
        sentence = line["sentence"]
        label = line["aspect"]
        # One negative sentence per other label, sampled uniformly from its bucket.
        sentence_out = [(random.choice(label_sentence_dict[label_out])[0], label_out) for label_out in num_label if label_out != label]
        all_sentence = [(sentence, label)] + sentence_out

        all_input_ids = list()
        all_input_ids_org = list()
        all_input_mask = list()
        all_segment_ids = list()
        all_lm_labels_ids = list()
        all_is_next = list()
        all_tail_idxs = list()
        all_sentence_labels = list()
        for id, sentence_label in enumerate(all_sentence):
            tokens_a = tokenizer.tokenize(sentence_label[0])

            if "</s>" in tokens_a:
                print("Have more than 1 </s>")
                for i in range(len(tokens_a)):
                    if tokens_a[i] == "</s>":
                        tokens_a[i] = "s"

            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(torch.tensor(label_map[sentence_label[1]]))

        cur_tensors = (torch.stack(all_input_ids),
                       torch.stack(all_input_ids_org),
                       torch.stack(all_input_mask),
                       torch.stack(all_segment_ids),
                       torch.stack(all_lm_labels_ids),
                       torch.stack(all_is_next),
                       torch.stack(all_tail_idxs),
                       torch.stack(all_sentence_labels))

        all_cur_tensors.append(cur_tensors)

    return all_cur_tensors
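# Each element returned by the binary loader stacks one positive sentence with
# one sampled negative per remaining aspect, so with N aspect labels every
# group has N rows per feature. An illustrative check:
#
#   groups = in_Domain_Task_Data_binary("data/yelp/", tokenizer, 128)
#   input_ids = groups[0][0]
#   print(input_ids.shape)   # torch.Size([N, 128]) for N aspect labels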
def AugmentationData_Domain(top_k, tokenizer, max_seq_length):
    """Turn retrieved general-domain sentences (a torch.topk result) into one stacked feature batch."""
    top_k_shape = top_k.indices.shape
    ids = top_k.indices.reshape(top_k_shape[0]*top_k_shape[1]).tolist()

    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()

    for id, i in enumerate(ids):
        t1 = data[str(i)]['sentence']

        tokens_a = tokenizer.tokenize(t1)

        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for j in range(len(tokens_a)):  # j, not i: i is the retrieved doc id
                if tokens_a[j] == "</s>":
                    tokens_a[j] = "s"

        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        all_input_ids.append(torch.tensor(cur_features.input_ids))
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
        all_input_mask.append(torch.tensor(cur_features.input_mask))
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
        all_is_next.append(torch.tensor(0))
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))

    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs))

    return cur_tensors
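# Sketch of how a torch.topk result feeds this function. A toy score matrix
# stands in for the real query-document similarities computed in the training
# loop below; sizes are placeholders:
#
#   scores = torch.randn(train_batch_size, num_docs)
#   bottom_k = torch.topk(scores, k=16, dim=1, largest=False, sorted=True)
#   batch = AugmentationData_Domain(bottom_k, tokenizer, max_seq_length=128)
#   # batch[0] has shape (train_batch_size*16, 128): one row per retrieved id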
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
    """Build a task batch from the retrieved sentences, appending the original batch (add_org) behind them."""
    top_k_shape = top_k.indices.shape
    sentence_ids = top_k.indices

    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()

    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org

    for id_1, sent in enumerate(sentence_ids):
        for id_2, sent_id in enumerate(sent):
            t1 = data[str(int(sent_id))]['sentence']

            tokens_a = tokenizer.tokenize(t1)

            if "</s>" in tokens_a:
                print("Have more than 1 </s>")
                #tokens_a[tokens_a.index("<s>")] = "s"
                for i in range(len(tokens_a)):
                    if tokens_a[i] == "</s>":
                        tokens_a[i] = "s"

            cur_example = InputExample(guid=id_2, tokens_a=tokens_a, tokens_b=None, is_next=0)

            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            # Retrieved sentences inherit the label of the query they were retrieved for.
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))

        all_input_ids.append(input_ids_[id_1])
        all_input_ids_org.append(input_ids_org_[id_1])
        all_input_mask.append(input_mask_[id_1])
        all_segment_ids.append(segment_ids_[id_1])
        all_lm_labels_ids.append(lm_label_ids_[id_1])
        all_is_next.append(is_next_[id_1])
        all_tail_idxs.append(tail_idxs_[id_1])
        all_sentence_labels.append(sentence_label_[id_1])

    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_sentence_labels))

    return cur_tensors
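# Unlike AugmentationData_Domain, this variant appends each original example
# behind its k retrieved neighbours, so every query contributes k+1 rows.
# Illustrative, with the same toy scores as above:
#
#   top_k = torch.topk(scores, k=16, dim=1, largest=True, sorted=True)
#   batch = AugmentationData_Task(top_k, tokenizer, 128, add_org=batch_)
#   # batch[0] has shape (train_batch_size*(16+1), 128)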
class Dataset_noNext(Dataset):
    """Corpus dataset that serves single sentences (no next-sentence pairs)."""

    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.num_docs = 0
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keeps the next line in memory between iterations
        self.current_random_doc = 0

        # map sample index to doc and line
        self.sample_to_doc = []

        # load samples into memory
        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        # an empty line ends the current document
                        self.all_docs.append(doc)
                        doc = []
                        # remove last added sample: there is no subsequent line
                        self.sample_to_doc.pop()
                    else:
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1

            # if the last row is not empty, we still need to add the last document
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if the doc does not end with an empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # last line of a doc won't be used, because there's no "nextSentence"
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from the beginning of the file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        t1, is_next_label = self.random_sent(item)
        if is_next_label is None:
            # no next-sentence task in this variant
            is_next_label = 0

        # tokenize (self.tokenizer, not the module-level one)
        tokens_a = self.tokenizer.tokenize(t1)

        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"

        # combine into one sample
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs))

        return cur_tensors
    def random_sent(self, index):
        """
        Get one sample from the corpus. Unlike the next-sentence variant this
        dataset is derived from, only a single sentence is served.
        :param index: int, index of sample.
        :return: (str, None), sentence 1 and a None next-sentence label
        """
        t1, t2 = self.get_corpus_line(index)
        return t1, None

    def get_corpus_line(self, item):
        """
        Get one sample from the corpus.
        :param item: int, index of sample.
        :return: (str, str), the sentence and an empty second slot
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            # used later to avoid a random sentence from the same doc
            self.current_doc = sample["doc_id"]
            return t1, t2
        else:
            if self.line_buffer is None:
                # read first line of file
                t1 = next(self.file).strip()
            else:
                # use the buffered line from the previous iteration as new t1
                t1 = self.line_buffer
                # skip empty rows that separate documents, tracking the doc id
                while t1 == "":
                    t1 = next(self.file).strip()
                    self.current_doc = self.current_doc+1
            self.line_buffer = next(self.file).strip()
            return t1, t2

    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Make sure the random document is not the same as the one currently
        # being processed; for a large corpus this rarely loops more than once.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                # pick a random line
                for _ in range(rand_index):
                    line = self.get_next_line()
            # check that the picked random line really is from another doc
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            # keep track of which random document we are currently looking at
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line
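# Typical usage of the dataset (hedged: assumes a plain-text corpus with one
# sentence per line and blank lines between documents, which is the format
# the loading code above expects):
#
#   dataset = Dataset_noNext("corpus.txt", tokenizer, seq_len=128, on_memory=True)
#   loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=32)
#   input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, \
#       is_next, tail_idxs = next(iter(loader))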
class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            tokens_a: string. The untokenized text of the first sequence. For single
                sequence tasks, only this sequence must be specified.
            tokens_b: (Optional) string. The untokenized text of the second sequence.
                Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        self.input_ids = input_ids
        self.input_ids_org = input_ids_org
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
        self.tail_idxs = tail_idxs
def random_word(tokens, tokenizer):
    """
    Masking some random tokens for the Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask a token with 15% probability, split 80/10/10 between <mask>,
        # a random vocab token, and the unchanged token (as in BERT)
        if prob < 0.15:
            prob /= 0.15
            if prob < 0.8:
                tokens[i] = "<mask>"
            elif prob < 0.9:
                candidate_id = random.randint(0, tokenizer.vocab_size-1)
                tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
                if tokens[i] is None:
                    candidate_id = random.randint(0, tokenizer.vocab_size-1)
                    tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)

            # predict the original token at every masked position
            try:
                w = tokenizer.convert_tokens_to_ids(token)
                if w is None:
                    print("No id found for token '{}'".format(token))
                output_label.append(w)
            except KeyError:
                # should not occur with a BPE vocab
                w = tokenizer.convert_tokens_to_ids("<unk>")
                output_label.append(w)
                logger.warning("Cannot find token '{}' in vocab. Using <unk> instead".format(token))
        else:
            # unmasked positions carry no LM label
            output_label.append(-1)

    return tokens, output_label
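# A quick demonstration of the masking behaviour (assumes roberta-base; the
# output is stochastic, roughly 15% of positions get a non -1 label):
#
#   toks = tokenizer.tokenize("the food was great")
#   masked, labels = random_word(list(toks), tokenizer)
#   # masked: tokens with ~80/10/10 <mask>/random/unchanged replacements
#   # labels: original token ids at masked positions, -1 everywhere else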
def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b

    # Truncate so that the sequence plus <s> and </s> fits max_seq_length.
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)

    # keep the original (unmasked) tokens for the task branch
    tokens_a_org = tokens_a.copy()
    tokens_a, t1_label = random_word(tokens_a, tokenizer)

    # concatenate lm labels and account for <s> and </s>
    lm_label_ids = ([-1] + t1_label + [-1])

    # build the masked and the original token sequences side by side
    tokens = []
    tokens_org = []
    segment_ids = []
    tokens.append("<s>")
    tokens_org.append("<s>")
    segment_ids.append(0)
    for i, token in enumerate(tokens_a):
        tokens.append(tokens_a[i])
        tokens_org.append(tokens_a_org[i])
        segment_ids.append(0)
    tokens.append("s")  # tail marker before the closing separator
    tokens_org.append("s")
    segment_ids.append(0)
    tokens.append("</s>")
    tokens_org.append("</s>")
    segment_ids.append(0)

    input_ids = tokenizer.encode(tokens, add_special_tokens=False)
    input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
    tail_idxs = len(input_ids)+1

    input_ids = [w if w is not None else 0 for w in input_ids]
    input_ids_org = [w if w is not None else 0 for w in input_ids_org]

    # The mask has 1 for real tokens and 0 for padding tokens; only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    while len(input_ids) < max_seq_length:
        input_ids.append(pad_id)
        input_ids_org.append(pad_id)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_ids_org) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    # log the first few examples
    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
            [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
            "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))

    features = InputFeatures(input_ids=input_ids,
                             input_ids_org=input_ids_org,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next,
                             tail_idxs=tail_idxs)
    return features
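# Feature layout produced above, for a toy example (values are illustrative):
#
#   example = InputExample(guid=0, tokens_a=tokenizer.tokenize("good food"), is_next=0)
#   f = convert_example_to_features(example, 16, tokenizer)
#   # f.input_ids     -> masked ids, <pad>-padded to max_seq_length
#   # f.input_ids_org -> unmasked ids of the same sentence
#   # f.lm_label_ids  -> original ids at masked slots, -1 elsewhere
#   # f.tail_idxs     -> position just past the last real token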
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    # (argument specs follow the conventions of the BERT finetuning scripts
    # this runner is based on; elided defaults are filled in accordingly)
    parser.add_argument("--data_dir_indomain", default=None, type=str, required=True,
                        help="The input train corpus. (In Domain)")
    parser.add_argument("--data_dir_outdomain", default=None, type=str, required=True,
                        help="The input train corpus. (Out Domain)")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--augment_times", default=None, type=int, required=True,
                        help="Default batch_size/augment_times to save model")

    ## Other parameters
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate", default=3e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory", action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale', type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--num_labels_task", default=None, type=int,
                        help="num_labels_task")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task", default=None, type=int,
                        help="task id (the original spec for this flag was not shown; placeholder values)")

    args = parser.parse_args()
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which takes care of synchronizing GPUs.
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)

    # Prepare the training data and the total number of optimizer updates.
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.data_dir_indomain)
        train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare the model.
    model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
    model.to(device)

    # Prepare the optimizer: decoupled weight decay for everything except
    # biases and LayerNorm weights.
    param_optimizer = list(model.named_parameters())
    #for par in param_optimizer:
    #    ...  # (the per-parameter inspection that stood here was elided)

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

        output_loss_file = os.path.join(args.output_dir, "loss")
        loss_fout = open(output_loss_file, 'w')

        model.train()
        # alpha ramps the retrieved-query loss up linearly over the whole run.
        alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
        # Assumption: k sentences are retrieved per query. The line defining k
        # was elided from this excerpt; args.augment_times is a placeholder.
        k = args.augment_times
        for epo in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch_ = tuple(t.to(device) for t in batch_)

                input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = batch_

                # Domain- and task-level representations of the current batch.
                in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")

                # Score every general-domain sentence against the batch (on CPU).
                query_domain = in_domain_rep.float().to("cpu")
                query_task = in_task_rep.float().to("cpu")

                # Last-layer-only scoring, superseded by the all-layer scoring below:
                #results_domain = torch.matmul(query_domain, docs[-1,:,:].T)
                #results_task = torch.matmul(query_task, docs[-1,:,:].T)

                results_domain = torch.matmul(docs, query_domain.transpose(0,1))
                results_domain = results_domain.transpose(1,2).transpose(0,1).sum(2)
                results_task = torch.matmul(docs, query_task.transpose(0,1))
                results_task = results_task.transpose(1,2).transpose(0,1).sum(2)
                #print("Time:", (end-start)/60)  # timing instrumentation; start/end were set in elided lines

                results = results_domain + results_task

                ### Domain branch: augment with the k lowest-scoring sentences.
                bottom_k = torch.topk(results, k, dim=1, largest=False, sorted=True)
                batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
                batch = tuple(t.to(device) for t in batch)
                # Only need input_ids
                input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs = batch
                domain_loss = model(input_ids_org=input_ids_org, masked_lm_labels=lm_label_ids, attention_mask=input_mask, func="domain_class", in_domain_rep=in_domain_rep.to(device))

                ### Task branch: augment with the k highest-scoring sentences plus the original batch.
                top_k = torch.topk(results, k, dim=1, largest=True, sorted=True)
                batch_ = tuple(t.to("cpu") for t in batch_)
                batch = AugmentationData_Task(top_k, tokenizer, args.max_seq_length, add_org=batch_)
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label = batch

                task_loss, mlm_loss = model(input_ids=input_ids, input_ids_org=input_ids_org, sentence_label=sentence_label, lm_label=lm_label_ids, attention_mask=input_mask, func="task_class and mlm", in_domain_rep=in_domain_rep.to(device), batch_size=args.train_batch_size)

                loss = domain_loss + task_loss + mlm_loss

                # Alternative objective (overrides the combined loss above):
                # classify the original batch and the retrieved queries
                # separately, scaling the query loss by alpha over training.
                task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentence_label_, attention_mask=input_mask_, func="task_class")
                task_loss_query, class_logit_query = model(input_ids_org=input_ids_org, sentence_label=sentence_label, attention_mask=input_mask, func="task_class")
                loss = task_loss_org + (task_loss_query*alpha*epo*step)/k

                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss_fout.write("{}\n".format(loss.item()))

                tr_loss += loss.item()
                nb_tr_examples += input_ids_.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Checkpoint at the end of every epoch.
            model_to_save = model.module if hasattr(model, 'module') else model
            output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
            torch.save(model_to_save.state_dict(), output_model_file)

            #if args.num_train_epochs/args.augment_times in [1,2,3]:
            if (args.num_train_epochs/(args.augment_times/5))%5 == 0:
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
                torch.save(model_to_save.state_dict(), output_model_file)

        # Save the final fine-tuned model.
        logger.info("** ** * Saving fine-tuned model ** ** * ")
        model_to_save = model.module if hasattr(model, 'module') else model
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        loss_fout.close()
        torch.save(model_to_save.state_dict(), output_model_file)
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""
    # Simple heuristic: drop one token at a time from the end. Only tokens_a
    # is used here, since tokens_b is None for the single-sentence examples
    # in this runner.
    while True:
        total_length = len(tokens_a)
        if total_length <= max_length:
            break
        tokens_a.pop()


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


if __name__ == "__main__":
    main()
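# Illustrative launch command built from the flags defined above (paths and
# sizes are placeholders, not the authors' settings):
#
#   python this_script.py \
#       --data_dir_indomain data/yelp/ \
#       --data_dir_outdomain data/opendomain/ \
#       --pretrain_model roberta-base \
#       --output_dir out/ \
#       --augment_times 16 \
#       --num_labels_task 5 \
#       --do_train --on_memory \
#       --train_batch_size 8 --max_seq_length 128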