CSS-LM

init_scl_t.py 
1427 lines · 58.1 KB
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
import os
import random
from io import open
import json
import time
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from torch.nn import CrossEntropyLoss

from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
#from transformers.modeling_roberta import RobertaForMaskedLMDomainTask
from transformers.modeling_roberta_updateRep_self import RobertaForMaskedLMDomainTask
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
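# reduction='none' keeps the per-element cross-entropy values instead of averaging them,
# presumably so the caller can mask or re-weight individual losses later in the script.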


def get_parameter(parser):

    ## Required parameters
    parser.add_argument("--data_dir_indomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus (in-domain).")
    parser.add_argument("--data_dir_outdomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus (out-of-domain).")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Pre-trained model name or path (this script loads a RoBERTa checkpoint).")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--augment_times",
                        default=None,
                        type=int,
                        required=True,
                        help="Default batch_size/augment_times to save model")
    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type = float, default = 0,
                        help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    ####
    parser.add_argument("--num_labels_task",
                        default=None, type=int,
                        required=True,
                        help="num_labels_task")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level',
                        type=str,
                        default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task",
                        default=0,
                        type=int,
                        required=True,
                        help="Choose Task")
    parser.add_argument("--K",
                        default=None,
                        type=int,
                        required=True,
                        help="K size")
    ####
    return parser


def return_Classifier(weight, bias, dim_in, dim_out):
    #LeakyReLU = torch.nn.LeakyReLU
    classifier = torch.nn.Linear(dim_in, dim_out , bias=True)
    #print(classifier)
    #print(classifier.weight)
    #print(classifier.weight.shape)
    #print(classifier.weight.data)
    #print(classifier.weight.data.shape)
    #print("---")
    classifier.weight.data = weight.to("cpu")
    classifier.bias.data = bias.to("cpu")
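    # Note: setting .requires_grad on the nn.Module below only attaches an attribute;
    # if the intent is to freeze this classifier, each parameter would need
    # requires_grad_(False) (e.g. classifier.weight.requires_grad_(False)).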
    classifier.requires_grad=False
    #print(classifier)
    #print(classifier.weight)
    #print(classifier.weight.shape)
    #print("---")
    #exit()
    #print(classifier)
    #exit()
    #logit = LeakyReLU(classifier)
    return classifier


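# Loads the pre-computed general-domain (out-of-domain) document representations.
# train_head.pt / train_tail.pt are assumed to hold one tensor per document
# (e.g. [num_docs, num_layers, hidden]); train.json holds the raw sentences.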
def load_GeneralDomain(dir_data_out):

    ###
    print("===========")
    print("Load CLS.pt and train.json")
    print("-----------")
    docs_head = torch.load(dir_data_out+"train_head.pt")
    docs_tail = torch.load(dir_data_out+"train_tail.pt")
    print("CLS.pt Done")
    print(docs_head.shape)
    print(docs_tail.shape)
    print("-----------")
    with open(dir_data_out+"train.json") as file:
        data = json.load(file)
    print("train.json Done")
    print("===========")
    docs_tail_head = torch.cat([docs_tail, docs_head],2)
    return docs_tail_head, docs_head, docs_tail, data
    ###

def load_InDomain(dir_data_in):
    in_data=dict()
    with open(dir_data_in+"train.json") as file:
        data = json.load(file)
    for id, line in enumerate(data):
        in_data[id]=line["sentence"]
    return in_data

parser = argparse.ArgumentParser()
parser = get_parameter(parser)
args = parser.parse_args()


docs_tail_head, docs_head, docs_tail, data = load_GeneralDomain(args.data_dir_outdomain)
in_data = load_InDomain(args.data_dir_indomain)
######
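# If the loaded representations still carry a per-layer dimension (shape[1] != 1),
# collapse it by averaging so every document keeps a single head/tail vector.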
if docs_head.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
    #last <s>
    #docs = docs[:,0,:].unsqueeze(1)
    #mean 13 layers <s>
    docs_head = docs_head.mean(1).unsqueeze(1)
    print(docs_head.shape)
else:
    print(docs_head.shape)
if docs_tail.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
    #last <s>
    #docs = docs[:,0,:].unsqueeze(1)
    #mean 13 layers <s>
    docs_tail = docs_tail.mean(1).unsqueeze(1)
    print(docs_tail.shape)
else:
    print(docs_tail.shape)
######
245

246
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
247
    ###Open
248
    with open(data_dir_indomain+"train.json") as file:
249
        data = json.load(file)
250

251
    ###Preprocess
252
    num_label_list = list()
253
    label_sentence_dict = dict()
254
    num_sentiment_label_list = list()
255
    sentiment_label_dict = dict()
256
    for line in data:
257
        #line["sentence"]
258
        #line["aspect"]
259
        #line["sentiment"]
260
        num_sentiment_label_list.append(line["sentiment"])
261
        #num_label_list.append(line["aspect"])
262
        num_label_list.append(line["sentiment"])
263

264
    num_label = sorted(list(set(num_label_list)))
265
    label_map = {label : i for i , label in enumerate(num_label)}
266
    num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
267
    sentiment_label_map = {label : i for i , label in enumerate(num_sentiment_label)}
268
    print("=======")
269
    print("label_map:")
270
    print(label_map)
271
    print("=======")
272
    print("=======")
273
    print("sentiment_label_map:")
274
    print(sentiment_label_map)
275
    print("=======")
276

277
    ###Create data: 1 choosed data along with the rest of 7 class data
278

279
    '''
280
    all_input_ids = list()
281
    all_input_mask = list()
282
    all_segment_ids = list()
283
    all_lm_labels_ids = list()
284
    all_is_next = list()
285
    all_tail_idxs = list()
286
    all_sentence_labels = list()
287
    '''
288
    cur_tensors_list = list()
289
    #print(list(label_map.values()))
290
    candidate_label_list = list(label_map.values())
291
    candidate_sentiment_label_list = list(sentiment_label_map.values())
292
    all_type_sentence = [0]*len(candidate_label_list)
293
    all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
294
    for line in data:
295
        #line["sentence"]
296
        #line["aspect"]
297
        sentiment = line["sentiment"]
298
        sentence = line["sentence"]
299
        #label = line["aspect"]
300
        label = line["sentiment"]
301

302

303
        tokens_a = tokenizer.tokenize(sentence)
304
        #input_ids = tokenizer.encode(sentence, add_special_tokens=False)
305
        '''
306
        if "</s>" in tokens_a:
307
            print("Have more than 1 </s>")
308
            #tokens_a[tokens_a.index("<s>")] = "s"
309
            for i in range(len(tokens_a)):
310
                if tokens_a[i] == "</s>":
311
                    tokens_a[i] == "s"
312
        '''
313

314

315
        # tokenize
316
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
317
        # transform sample to features
318
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
319

320
        cur_tensors = (torch.tensor(cur_features.input_ids),
321
                       torch.tensor(cur_features.input_ids_org),
322
                       torch.tensor(cur_features.input_mask),
323
                       torch.tensor(cur_features.segment_ids),
324
                       torch.tensor(cur_features.lm_label_ids),
325
                       torch.tensor(0),
326
                       torch.tensor(cur_features.tail_idxs),
327
                       torch.tensor(label_map[label]),
328
                       torch.tensor(sentiment_label_map[sentiment]))
329

330
        cur_tensors_list.append(cur_tensors)
331

332
        ###
333
        if label_map[label] in candidate_label_list:
334
            all_type_sentence[label_map[label]]=cur_tensors
335
            candidate_label_list.remove(label_map[label])
336

337
        if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
338
            #print("----")
339
            #print(sentiment_label_map[sentiment])
340
            #print("----")
341
            all_type_sentiment_sentence[sentiment_label_map[sentiment]]=cur_tensors
342
            candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])
343
        ###
344

345

346

347

348
    return all_type_sentiment_sentence, cur_tensors_list
349

350

351

352
def AugmentationData_Domain(train_batch_size, k, tokenizer, max_seq_length):
353
    #top_k_shape = top_k.indices.shape
354
    #ids = top_k.indices.reshape(top_k_shape[0]*top_k_shape[1]).tolist()
355
    #top_k_shape = top_k["indices"].shape
356
    #ids_pos = top_k["indices"].reshape(top_k_shape[0]*top_k_shape[1]).tolist()
357
    #ids = top_k["indices"]
358

359
    #bottom_k_shape = bottom_k["indices"].shape
360
    #ids_neg = bottom_k["indices"].reshape(bottom_k_shape[0]*bottom_k_shape[1]).tolist()
361
    ids_pos = random.sample(range(0,len(in_data)),train_batch_size)
362

363
    ids_neg = random.sample(range(0,len(data)),train_batch_size*k)
364

365
    #ids = ids_pos+ids_neg
366

367

368
    all_input_ids = list()
369
    all_input_ids_org = list()
370
    all_input_mask = list()
371
    all_segment_ids = list()
372
    all_lm_labels_ids = list()
373
    all_is_next = list()
374
    all_tail_idxs = list()
375
    all_id_domain = list()
376

377
    for id, i in enumerate(ids_pos):
378
        t1 = data[str(i)]['sentence']
379

380
        #tokens_a = tokenizer.tokenize(t1)
381
        tokens_a = tokenizer.tokenize(t1)
382
        # tokenize
383
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
384

385
        # transform sample to features
386
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
387

388
        all_input_ids.append(torch.tensor(cur_features.input_ids))
389
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
390
        all_input_mask.append(torch.tensor(cur_features.input_mask))
391
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
392
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
393
        all_is_next.append(torch.tensor(0))
394
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
395
        #if i in ids_neg:
396
        #    all_id_domain.append(torch.tensor(0))
397
        #elif i in ids_pos:
398
        #    all_id_domain.append(torch.tensor(1))
399
        all_id_domain.append(torch.tensor(1))
400

401

402
    for id, i in enumerate(ids_neg):
403
        t1 = data[str(i)]['sentence']
404

405
        #tokens_a = tokenizer.tokenize(t1)
406
        tokens_a = tokenizer.tokenize(t1)
407
        # tokenize
408
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
409

410
        # transform sample to features
411
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
412

413
        all_input_ids.append(torch.tensor(cur_features.input_ids))
414
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
415
        all_input_mask.append(torch.tensor(cur_features.input_mask))
416
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
417
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
418
        all_is_next.append(torch.tensor(0))
419
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
420
        #if i in ids_neg:
421
        #    all_id_domain.append(torch.tensor(0))
422
        #elif i in ids_pos:
423
        #    all_id_domain.append(torch.tensor(1))
424
        all_id_domain.append(torch.tensor(0))
425

426

427

428
    cur_tensors = (torch.stack(all_input_ids),
429
                   torch.stack(all_input_ids_org),
430
                   torch.stack(all_input_mask),
431
                   torch.stack(all_segment_ids),
432
                   torch.stack(all_lm_labels_ids),
433
                   torch.stack(all_is_next),
434
                   torch.stack(all_tail_idxs),
435
                   torch.stack(all_id_domain))
436

437
    return cur_tensors
438

439

440
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
441
    top_k_shape = top_k["indices"].shape
442
    sentence_ids = top_k["indices"]
443

444
    all_input_ids = list()
445
    all_input_ids_org = list()
446
    all_input_mask = list()
447
    all_segment_ids = list()
448
    all_lm_labels_ids = list()
449
    all_is_next = list()
450
    all_tail_idxs = list()
451
    all_sentence_labels = list()
452
    all_sentiment_labels = list()
453

454
    add_org = tuple(t.to('cpu') for t in add_org)
455
    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
456
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
457

458
    ###
459
    #print("input_ids_",input_ids_.shape)
460
    #print("---")
461
    #print("sentence_ids",sentence_ids.shape)
462
    #print("---")
463
    #print("sentence_label_",sentence_label_.shape)
464
    #exit()
465

466

467
    for id_1, sent in enumerate(sentence_ids):
468
        for id_2, sent_id in enumerate(sent):
469

470
            t1 = data[str(int(sent_id))]['sentence']
471

472
            tokens_a = tokenizer.tokenize(t1)
473

474
            # tokenize
475
            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
476

477
            # transform sample to features
478
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
479

480
            all_input_ids.append(torch.tensor(cur_features.input_ids))
481
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
482
            all_input_mask.append(torch.tensor(cur_features.input_mask))
483
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
484
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
485
            all_is_next.append(torch.tensor(0))
486
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
487
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
488
            all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))
489

490
        all_input_ids.append(input_ids_[id_1])
491
        all_input_ids_org.append(input_ids_org_[id_1])
492
        all_input_mask.append(input_mask_[id_1])
493
        all_segment_ids.append(segment_ids_[id_1])
494
        all_lm_labels_ids.append(lm_label_ids_[id_1])
495
        all_is_next.append(is_next_[id_1])
496
        all_tail_idxs.append(tail_idxs_[id_1])
497
        all_sentence_labels.append(sentence_label_[id_1])
498
        all_sentiment_labels.append(sentiment_label_[id_1])
499

500

501
    cur_tensors = (torch.stack(all_input_ids),
502
                   torch.stack(all_input_ids_org),
503
                   torch.stack(all_input_mask),
504
                   torch.stack(all_segment_ids),
505
                   torch.stack(all_lm_labels_ids),
506
                   torch.stack(all_is_next),
507
                   torch.stack(all_tail_idxs),
508
                   torch.stack(all_sentence_labels),
509
                   torch.stack(all_sentiment_labels)
510
                   )
511

512

513
    return cur_tensors
514

515

516

517

518

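# The two helpers below build pairwise contrastive inputs from a batch: each example's
# representation is concatenated with every other example's, and the binary label is 1
# when the pair shares the same (sentiment) label. The *_DT variant additionally
# concatenates the domain and task representations before pairing.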
def AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None, in_domain_rep=None):
    '''
    top_k_shape = top_k.indices.shape
    sentence_ids = top_k.indices
    '''
    #top_k_shape = top_k["indices"].shape
    #sentence_ids = top_k["indices"]


    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org


    #uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))

    all_sentence_binary_label = list()
    #all_in_task_rep_comb = list()
    all_in_rep_comb = list()

    for id_1, num in enumerate(sentence_label_):
        #print([sentence_label_==num])
        #print(type([sentence_label_==num]))
        sentence_label_int = (sentence_label_==num).to(torch.long)
        #print(sentence_label_int)
        #print(sentence_label_int.shape)
        #print(in_task_rep[id_1].shape)
        #print(in_task_rep.shape)
        #exit()
        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
        in_domain_rep_append = in_domain_rep[id_1].unsqueeze(0).expand(in_domain_rep.shape[0],-1)
        #print(in_task_rep_append)
        #print(in_task_rep_append.shape)
        in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
        in_domain_rep_comb = torch.cat((in_domain_rep_append,in_domain_rep),-1)
        #print(in_task_rep_comb)
        #print(in_task_rep_comb.shape)
        #exit()
        #sentence_label_int = sentence_label_int.to(torch.float32)
        #print(sentence_label_int)
        #exit()
        #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
        #all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
        #print(in_task_rep_comb.shape)
        #print(in_domain_rep_comb.shape)
        in_rep_comb = torch.cat([in_domain_rep_comb,in_task_rep_comb],-1)
        #print(in_rep.shape)
        #exit()
        all_sentence_binary_label.append(sentence_label_int)
        #all_in_task_rep_comb.append(in_task_rep_comb)
        all_in_rep_comb.append(in_rep_comb)
    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    #all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
    all_in_rep_comb = torch.stack(all_in_rep_comb)

    #cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
    cur_tensors = (all_in_rep_comb, all_sentence_binary_label)

    return cur_tensors




def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
    '''
    top_k_shape = top_k.indices.shape
    sentence_ids = top_k.indices
    '''
    #top_k_shape = top_k["indices"].shape
    #sentence_ids = top_k["indices"]


    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org


    #uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))

    all_sentence_binary_label = list()
    all_in_task_rep_comb = list()

    for id_1, num in enumerate(sentence_label_):
        #print([sentence_label_==num])
        #print(type([sentence_label_==num]))
        sentence_label_int = (sentence_label_==num).to(torch.long)
        #print(sentence_label_int)
        #print(sentence_label_int.shape)
        #print(in_task_rep[id_1].shape)
        #print(in_task_rep.shape)
        #exit()
        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
        #print(in_task_rep_append)
        #print(in_task_rep_append.shape)
        in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
        #print(in_task_rep_comb)
        #print(in_task_rep_comb.shape)
        #exit()
        #sentence_label_int = sentence_label_int.to(torch.float32)
        #print(sentence_label_int)
        #exit()
        #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
        #all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
        all_sentence_binary_label.append(sentence_label_int)
        all_in_task_rep_comb.append(in_task_rep_comb)
    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)

    cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)

    return cur_tensors




class Dataset_noNext(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):

        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = [] # map sample index to doc and line

        # load samples into memory
        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        self.all_docs.append(doc)
                        doc = []
                        #remove last added sample because there won't be a subsequent line anymore in the doc
                        self.sample_to_doc.pop()
                    else:
                        #store as one sample
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1

            # if last row in file is not empty
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples later lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if doc does not end with empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        #t1, t2, is_next_label = self.random_sent(item)
        t1, is_next_label = self.random_sent(item)
        if is_next_label == None:
            is_next_label = 0


        #tokens_a = self.tokenizer.tokenize(t1)
        tokens_a = tokenizer.tokenize(t1)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"
        '''
        #tokens_b = self.tokenizer.tokenize(t2)

        # tokenize
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs))

        return cur_tensors

    def random_sent(self, index):
        """
        Get one sample from the corpus. Next-sentence prediction is not used in this
        dataset, so only a single sentence is returned.
        :param index: int, index of sample.
        :return: (str, NoneType), the sentence and a placeholder is_next label.
        """
        t1, t2 = self.get_corpus_line(index)
        return t1, None

    def get_corpus_line(self, item):
        """
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
        :param item: int, index of sample.
        :return: (str, str), two subsequent sentences from corpus
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            return t1, t2
            #return t1
        else:
            if self.line_buffer is None:
                # read first non-empty line of file
                while t1 == "" :
                    t1 = next(self.file).strip()
            else:
                # use t2 from previous iteration as new t1
                t1 = self.line_buffer
                # skip empty rows that are used for separating documents and keep track of current doc id
                while t1 == "":
                    t1 = next(self.file).strip()
                    self.current_doc = self.current_doc+1
            self.line_buffer = next(self.file).strip()

        assert t1 != ""
        return t1, t2


    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                #pick random line
                for _ in range(rand_index):
                    line = self.get_next_line()
            #check if our picked random line is really from another doc like we want it to be
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            #keep track of which document we are currently looking at to later avoid having the same doc as t1
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line


class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.
        Args:
            guid: Unique id for the example.
            tokens_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            tokens_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        self.input_ids = input_ids
        self.input_ids_org = input_ids_org
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
        self.tail_idxs = tail_idxs


def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):

        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            prob /= 0.15
            #candidate_id = random.randint(0,tokenizer.vocab_size)
            #print(tokenizer.convert_ids_to_tokens(candidate_id))


            # 80% randomly change token to mask token
            if prob < 0.8:
                #tokens[i] = "[MASK]"
                tokens[i] = "<mask>"

            # 10% randomly change token to random token
            elif prob < 0.9:
                #tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
                #tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
                candidate_id = random.randint(0, tokenizer.vocab_size - 1)  # valid ids are 0..vocab_size-1
                w = tokenizer.convert_ids_to_tokens(candidate_id)
                '''
                if tokens[i] == None:
                    candidate_id = 100
                    w = tokenizer.convert_ids_to_tokens(candidate_id)
                '''
                tokens[i] = w


            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                #output_label.append(tokenizer.vocab[token])
                w = tokenizer.convert_tokens_to_ids(token)
                if w!= None:
                    output_label.append(w)
                else:
                    print("No id found for this token")
                    exit()
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                #output_label.append(tokenizer.vocab["<unk>"])
                w = tokenizer.convert_tokens_to_ids("<unk>")
                output_label.append(w)
                logger.warning("Cannot find token '{}' in vocab. Using <unk> instead".format(token))
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label


def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    #now tokens_a is input_ids
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    #_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)

    #print(tokens_a)
    tokens_a_org = tokens_a.copy()
    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    #print("----")
    #print(tokens_a)
    #print(tokens_a_org)
    #exit()
    #print(t1_label)
    #exit()
    #tokens_b, t2_label = random_word(tokens_b, tokenizer)
    # concatenate lm labels and account for CLS, SEP, SEP
    #lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
    lm_label_ids = ([-1] + t1_label + [-1])

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0   0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    tokens_org = []
    segment_ids = []
    tokens.append("<s>")
    tokens_org.append("<s>")
    segment_ids.append(0)
    for i, token in enumerate(tokens_a):
        if token!="</s>":
            tokens.append(tokens_a[i])
            tokens_org.append(tokens_a_org[i])
            segment_ids.append(0)
        else:
            tokens.append("s")
            tokens_org.append("s")
            segment_ids.append(0)
    tokens.append("</s>")
    tokens_org.append("</s>")
    segment_ids.append(0)

    #tokens.append("[SEP]")
    #segment_ids.append(1)

    #input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = tokenizer.encode(tokens, add_special_tokens=False)
    input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
    tail_idxs = len(input_ids)-1
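    # tail_idxs points at the final </s> position, i.e. the last real token before
    # padding (presumably used downstream to pick out the "tail" representation).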

    #print(input_ids)
    input_ids = [w if w!=None else 0 for w in input_ids]
    input_ids_org = [w if w!=None else 0 for w in input_ids_org]
    #print(input_ids)
    #exit()

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
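    # Pad with the tokenizer's <pad> id rather than a hard-coded value so the
    # padding matches the RoBERTa vocabulary.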
    while len(input_ids) < max_seq_length:
        input_ids.append(pad_id)
        input_ids_org.append(pad_id)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    try:
        assert len(input_ids) == max_seq_length
        assert len(input_ids_org) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(lm_label_ids) == max_seq_length
    except:
        print("!!!Warning!!!")
        input_ids = input_ids[:max_seq_length-1]
        if 2 not in input_ids:
            input_ids += [2]
        else:
            input_ids += [pad_id]
        input_ids_org = input_ids_org[:max_seq_length-1]
        if 2 not in input_ids_org:
            input_ids_org += [2]
        else:
            input_ids_org += [pad_id]
        input_mask = input_mask[:max_seq_length-1]+[0]
        segment_ids = segment_ids[:max_seq_length-1]+[0]
        lm_label_ids = lm_label_ids[:max_seq_length-1]+[-1]
    '''
    flag=False
    if len(input_ids) != max_seq_length:
        print(len(input_ids))
        flag=True
    if len(input_ids_org) != max_seq_length:
        print(len(input_ids_org))
        flag=True
    if len(input_mask) != max_seq_length:
        print(len(input_mask))
        flag=True
    if len(segment_ids) != max_seq_length:
        print(len(segment_ids))
        flag=True
    if len(lm_label_ids) != max_seq_length:
        print(len(lm_label_ids))
        flag=True
    if flag == True:
        print("1165")
        exit()
    '''

    '''
    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))
    '''

    features = InputFeatures(input_ids=input_ids,
                             input_ids_org = input_ids_org,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next,
                             tail_idxs=tail_idxs)
    return features


def main():
    parser = argparse.ArgumentParser()

    parser = get_parameter(parser)

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
1123
    if not os.path.exists(args.output_dir):
1124
        os.makedirs(args.output_dir)
1125

1126
    #tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model, do_lower_case=args.do_lower_case)
1127
    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)
1128

1129

1130
    #train_examples = None
1131
    num_train_optimization_steps = None
1132
    if args.do_train:
1133
        print("Loading Train Dataset", args.data_dir_indomain)
1134
        #train_dataset = Dataset_noNext(args.data_dir, tokenizer, seq_len=args.max_seq_length, corpus_lines=None, on_memory=args.on_memory)
1135
        all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
1136
        num_train_optimization_steps = int(
1137
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
1138
        if args.local_rank != -1:
1139
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
1140

1141

1142

1143
    # Prepare model
1144
    model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1145
    #model = RobertaForSequenceClassification.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1146
    model.to(device)
1147

1148

1149

1150
    # Prepare optimizer
1151
    if args.do_train:
1152
        param_optimizer = list(model.named_parameters())
1153
        '''
1154
        for par in param_optimizer:
1155
            print(par[0])
1156
        exit()
1157
        '''
1158
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
1159
        optimizer_grouped_parameters = [
1160
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
1161
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
1162
            ]
1163
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
1164
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)
1165

1166
        if args.fp16:
1167
            try:
1168
                from apex import amp
1169
            except ImportError:
1170
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
1171
                exit()
1172

1173
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
1174

1175

1176
        if n_gpu > 1:
1177
            model = torch.nn.DataParallel(model)
1178

1179
        if args.local_rank != -1:
1180
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
1181

1182

1183

1184
    global_step = 0
1185
    if args.do_train:
1186
        logger.info("***** Running training *****")
1187
        logger.info("  Num examples = %d", len(train_dataset))
1188
        logger.info("  Batch size = %d", args.train_batch_size)
1189
        logger.info("  Num steps = %d", num_train_optimization_steps)
1190

1191
        if args.local_rank == -1:
1192
            train_sampler = RandomSampler(train_dataset)
1193
            #all_type_sentence_sampler = RandomSampler(all_type_sentence)
1194
        else:
1195
            #TODO: check if this works with current data generator from disk that relies on next(file)
1196
            # (it doesn't return item back by index)
1197
            train_sampler = DistributedSampler(train_dataset)
1198
            #all_type_sentence_sampler = DistributedSampler(all_type_sentence)
1199
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
1200
        #all_type_sentence_dataloader = DataLoader(all_type_sentence, sampler=all_type_sentence_sampler, batch_size=len(all_type_sentence_label))
1201

1202
        output_loss_file = os.path.join(args.output_dir, "loss")
1203
        loss_fout = open(output_loss_file, 'w')
1204

1205

1206
        output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
1207
        loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')
1208
        model.train()
1209

1210

1211

1212

1213
        #alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
1214
        #alpha = float(1/args.num_train_epochs)
1215
        alpha = float(1)
1216
        #k=8
1217
        k=10
1218
        #k = args.K
1219
        #k = 10
1220
        #k = 2
1221
        #retrive_gate = args.num_labels_task
1222
        #retrive_gate = len(train_dataset)/100
1223
        retrive_gate = 1
1224
        all_type_sentence_label = list()
1225
        all_previous_sentence_label = list()
1226
        all_type_sentiment_label = list()
1227
        all_previous_sentiment_label = list()
1228
        top_k_all_type = dict()
1229
        bottom_k_all_type = dict()
1230
        for epo in trange(int(args.num_train_epochs), desc="Epoch"):
1231
            tr_loss = 0
1232
            nb_tr_examples, nb_tr_steps = 0, 0
1233
            for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
1234

1235

1236
                #######################
1237
                ######################
1238

1239
                ###Normal mode
1240
                batch_ = tuple(t.to(device) for t in batch_)
1241
                input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_
1242

1243

1244
                ###
1245
                # Generate query representation
1246
                in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1247

1248

1249
                #################
                #################
                '''
                #Domain Binary Classifier - Outdomain
                #batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
                batch = AugmentationData_Domain(in_domain_rep.shape[0], k, tokenizer, args.max_seq_length)
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, domain_id = batch

                out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, lm_label=lm_label_ids, attention_mask=input_mask, func="in_domain_task_rep")
                #print("======")
                #print(domain_top_k["indices"].shape)
                #print(input_ids_org.shape)
                #print(out_domain_rep_tail.shape)
                #print(in_domain_rep.shape)
                #print("======")
                ############ Construct contrastive instances
                #print("=============")
                #print(out_domain_rep_tail.shape)
                #print(in_domain_rep.shape)
                #print("=============")
                comb_rep_pos = torch.cat([in_domain_rep, in_domain_rep.flip(0)], 1)
                in_domain_rep_ready = in_domain_rep.repeat(1, int(out_domain_rep_tail.shape[0]/in_domain_rep.shape[0])).reshape(out_domain_rep_tail.shape[0], out_domain_rep_tail.shape[1])
                comb_rep_unknow = torch.cat([in_domain_rep_ready, out_domain_rep_tail], 1)

                in_domain_binary_loss, domain_binary_logit = model(func="domain_binary_classifier", in_domain_rep=comb_rep_pos.to(device), out_domain_rep=comb_rep_unknow.to(device), domain_id=domain_id, use_detach=True)
                '''
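                # The triple-quoted block above is the domain-level contrastive objective, disabled
                # in this variant: comb_rep_pos pairs each in-domain representation with another
                # in-domain representation from the same batch (shape [B, 2H]) as positive pairs,
                # while comb_rep_unknow tiles each in-domain representation against the retrieved
                # open-domain representations (roughly [B*k, 2H]) as unknown/negative candidates;
                # the "domain_binary_classifier" head then scores every pair.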
                ############


                #################
                #################
                #Task Binary Classifier    in domain
                #Pseudo Task --> Won't bp to PLM: only train classifier [In domain data]
                #batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
                batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep)
                batch = tuple(t.to(device) for t in batch)
                all_in_task_rep_comb, all_sentence_binary_label = batch
                in_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)

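                # Task-level pairwise step (above): AugmentationData_Task_pos_and_neg presumably
                # enumerates pairs of in-batch task representations and labels each pair by whether
                # the two instances share the same task label; "task_binary_classifier" scores the
                # pairs, giving a supervised-contrastive-style binary loss. With use_detach=False
                # the representations are presumably left attached to the graph, so this loss can
                # also update the encoder rather than only the pair classifier.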
                #################
                #################
                #Train Task org - finetune
                #split into: in_dom and query_  --> different weight
                task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")

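                # Standard fine-tuning step (above): "task_class" runs the classification head on
                # the original inputs; task_loss_org is the classification loss against
                # sentiment_label_ and class_logit_org holds the per-class logits.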
                #################
                #################
                #Domain Task binary   including outdomain
                #batch = AugmentationData_Task(task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
                #batch = tuple(t.to(device) for t in batch)
                #input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
                #out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
                ###
                #batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
                #batch = tuple(t.to(device) for t in batch)
                #all_in_task_rep_comb, all_sentence_binary_label = batch
                #out_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
                ###


                #################
                #################
                #Domain-Task binary Level  (in domain task)
                ###
                '''
                batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
                batch = tuple(t.to(device) for t in batch)
                all_in_task_rep_comb, all_sentence_binary_label = batch
                in_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
                '''
                ###


                #################
                #################

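                # In this variant the domain-level and domain-task-level binary objectives above are
                # commented out, so only task_loss_org and in_task_binary_loss enter the total loss
                # computed below.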
                if n_gpu > 1:
                    #loss = task_loss_org.mean()*2 + in_task_binary_loss.mean() + in_domain_task_binary_loss.mean()*0.5
                    loss = task_loss_org.mean() + in_task_binary_loss.mean()
                    #loss = task_loss_org.mean() + in_domain_task_binary_loss.mean()
                    #loss = task_loss_org.mean()
                else:
                    #loss = mix_domain_binary_loss + (in_task_binary_loss + out_task_binary_loss)/2 + task_loss_org + out_domain_task_binary_loss
                    # Single-GPU/CPU path: combine the same losses; .mean() is a no-op on scalar
                    # losses, so this works whether the model returns scalars or per-example losses.
                    loss = task_loss_org.mean() + in_task_binary_loss.mean()

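                # The joint loss is the sum of the classification loss and the task-level pairwise
                # loss. When gradient accumulation is enabled it is scaled down below so the
                # accumulated gradient matches a single large-batch update; with --fp16 the backward
                # pass goes through apex amp's loss scaling.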
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                ###
                loss_fout.write("{}\n".format(loss.item()))
                ###

                ###
                #loss_fout_no_pseudo.write("{}\n".format(loss.item()-pseudo.item()))
                ###

                tr_loss += loss.item()
                #nb_tr_examples += input_ids.size(0)
                nb_tr_examples += input_ids_.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # With fp16/apex, clip the gradients of the amp master parameters;
                        # learning-rate warmup is handled by `scheduler` below in both branches.
                        #lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    ###
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    ###

                    optimizer.step()
                    ###
                    scheduler.step()
                    ###
                    #optimizer.zero_grad()
                    model.zero_grad()
                    global_step += 1


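            # Parameter-update cycle (above): every gradient_accumulation_steps batches the
            # gradients are clipped to max_grad_norm, optimizer.step() applies the update
            # (AdamW per the imports), scheduler.step() advances the linear warmup schedule,
            # and model.zero_grad() clears the accumulated gradients before global_step is
            # incremented.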
            if epo < -1:
                continue
            else:
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                #output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
                output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(epo))
                torch.save(model_to_save.state_dict(), output_model_file)
            ####
            '''
            #if args.num_train_epochs/args.augment_times in [1,2,3]:
            if (args.num_train_epochs/(args.augment_times/5))%5 == 0:
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
                torch.save(model_to_save.state_dict(), output_model_file)
            '''
            ####

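        # Checkpointing note: "epo < -1" is never true, so a checkpoint named
        # pytorch_model.bin_{epo} is written after every epoch; the final model is saved again
        # as pytorch_model.bin below.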
        loss_fout.close()
        loss_fout_no_pseudo.close()

        # Save a trained model
        logger.info("** ** * Saving fine-tuned model ** ** * ")
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        if args.do_train:
            torch.save(model_to_save.state_dict(), output_model_file)



def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates tokens_a in place to the maximum length (tokens_b is ignored in this variant)."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        #total_length = len(tokens_a) + len(tokens_b)
        total_length = len(tokens_a)
        if total_length <= max_length:
            break
        else:
            tokens_a.pop()


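# Example (hypothetical values) for _truncate_seq_pair: only tokens_a is truncated, in place.
#   tokens_a = ["the", "movie", "was", "great"]
#   _truncate_seq_pair(tokens_a, [], max_length=2)   # tokens_a -> ["the", "movie"]
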
def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


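# Note: accuracy() returns the number of correct predictions (a count, not a ratio), e.g.
#   accuracy(np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([0, 0])) == 1
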
if __name__ == "__main__":
    main()