CSS-LM

self_bert_training.py
1843 lines · 86.8 KB
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
import os
import random
from io import open
import json
import time
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from torch.nn import CrossEntropyLoss

from transformers import BertTokenizer, BertForMaskedLM, BertForSequenceClassification
#from transformers.modeling_bert import BertForMaskedLMDomainTask
from transformers.modeling_bert_updateRep_self import BertForMaskedLMDomainTask
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
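# reduction='none' keeps the per-element CE values; main() reshapes them into a
# (num_queries, num_docs) score matrix and ranks documents with torch.topk
# instead of averaging them into a single scalar loss.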


def get_parameter(parser):

    ## Required parameters
    parser.add_argument("--data_dir_indomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.(In Domain)")
    parser.add_argument("--data_dir_outdomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.(Out Domain)")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--augment_times",
                        default=None,
                        type=int,
                        required=True,
                        help="Default batch_size/augment_times to save model")
    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type = float, default = 0,
                        help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    ####
    parser.add_argument("--num_labels_task",
                        default=None, type=int,
                        required=True,
                        help="num_labels_task")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level',
                        type=str,
                        default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task",
                        default=0,
                        type=int,
                        required=True,
                        help="Choose Task")
    parser.add_argument("--K",
                        default=None,
                        type=int,
                        required=True,
                        help="K size")
    ####
    return parser


def return_Classifier(weight, bias, dim_in, dim_out):
    # Rebuild a frozen linear head on CPU from weights/biases exported by the model.
    classifier = torch.nn.Linear(dim_in, dim_out, bias=True)
    classifier.weight.data = weight.to("cpu")
    classifier.bias.data = bias.to("cpu")
    # Freeze the parameters themselves; setting requires_grad on the Module
    # object (as the original code did) has no effect on autograd.
    for p in classifier.parameters():
        p.requires_grad = False
    return classifier
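

# A minimal usage sketch (illustrative, never called): the frozen heads built
# by return_Classifier score pairwise-concatenated representations, so with
# BERT's 768-dim hidden states the input dim is 768*2 and the output is 2-way.
def _demo_return_Classifier():
    weight = torch.randn(2, 768 * 2)  # made-up stand-in for model(func="return_task_binary_classifier")
    bias = torch.randn(2)
    clf = return_Classifier(weight, bias, 768 * 2, 2)
    logits = clf(torch.randn(4, 768 * 2))  # -> tensor of shape (4, 2)
    return logits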


def load_GeneralDomain(dir_data_out):
    # NOTE: dir_data_out is expected to end with a path separator; the file
    # names below are appended to it directly.
    print("===========")
    print("Load train_head.pt / train_tail.pt and train.json")
    print("-----------")
    docs_head = torch.load(dir_data_out+"train_head.pt")
    docs_tail = torch.load(dir_data_out+"train_tail.pt")
    print("Representation tensors loaded")
    print(docs_head.shape)
    print(docs_tail.shape)
    print("-----------")
    with open(dir_data_out+"train.json") as file:
        data = json.load(file)
    print("train.json Done")
    print("===========")
    docs_tail_head = torch.cat([docs_tail, docs_head], 2)
    return docs_tail_head, docs_head, docs_tail, data
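
# Expected layout under --data_dir_outdomain (an inference from the loads
# above, not a documented spec): train_head.pt and train_tail.pt hold
# per-document representation tensors of shape (num_docs, num_layers_or_1,
# hidden), and train.json maps string ids ("0", "1", ...) to records with a
# "sentence" field, looked up below as data[str(i)]['sentence'].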


parser = argparse.ArgumentParser()
parser = get_parameter(parser)
args = parser.parse_args()
#print(args.data_dir_outdomain)
#exit()

docs_tail_head, docs_head, docs_tail, data = load_GeneralDomain(args.data_dir_outdomain)
######
if docs_head.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
    #last <s>
    #docs = docs[:,0,:].unsqueeze(1)
    #mean 13 layers <s>
    docs_head = docs_head.mean(1).unsqueeze(1)
    print(docs_head.shape)
else:
    print(docs_head.shape)
if docs_tail.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
    #last <s>
    #docs = docs[:,0,:].unsqueeze(1)
    #mean 13 layers <s>
    docs_tail = docs_tail.mean(1).unsqueeze(1)
    print(docs_tail.shape)
else:
    print(docs_tail.shape)
######
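
# Shape note: if the saved tensors are (num_docs, num_layers, hidden), the
# branches above average over the layer dimension and re-unsqueeze, so both
# cases end up as (num_docs, 1, hidden) before retrieval.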
238

239
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
240
    ###Open
241
    with open(data_dir_indomain+"train.json") as file:
242
        data = json.load(file)
243

244
    ###Preprocess
245
    num_label_list = list()
246
    label_sentence_dict = dict()
247
    num_sentiment_label_list = list()
248
    sentiment_label_dict = dict()
249
    for line in data:
250
        #line["sentence"]
251
        #line["aspect"]
252
        #line["sentiment"]
253
        num_sentiment_label_list.append(line["sentiment"])
254
        #num_label_list.append(line["aspect"])
255
        num_label_list.append(line["sentiment"])
256

257
    num_label = sorted(list(set(num_label_list)))
258
    label_map = {label : i for i , label in enumerate(num_label)}
259
    num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
260
    sentiment_label_map = {label : i for i , label in enumerate(num_sentiment_label)}
261
    print("=======")
262
    print("label_map:")
263
    print(label_map)
264
    print("=======")
265
    print("=======")
266
    print("sentiment_label_map:")
267
    print(sentiment_label_map)
268
    print("=======")
269

    ###Create data: 1 chosen example along with the rest of the 7 class data

    '''
    all_input_ids = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()
    '''
    cur_tensors_list = list()
    #print(list(label_map.values()))
    candidate_label_list = list(label_map.values())
    candidate_sentiment_label_list = list(sentiment_label_map.values())
    all_type_sentence = [0]*len(candidate_label_list)
    all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
    for guid, line in enumerate(data):
        #line["sentence"]
        #line["aspect"]
        sentiment = line["sentiment"]
        sentence = line["sentence"]
        #label = line["aspect"]
        label = line["sentiment"]


        tokens_a = tokenizer.tokenize(sentence)
        #input_ids = tokenizer.encode(sentence, add_special_tokens=False)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"
        '''


        # tokenize (the original passed the builtin `id` as guid; use the
        # enumeration index instead)
        cur_example = InputExample(guid=guid, tokens_a=tokens_a, tokens_b=None, is_next=0)
        # transform sample to features
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(0),
                       torch.tensor(cur_features.tail_idxs),
                       torch.tensor(label_map[label]),
                       torch.tensor(sentiment_label_map[sentiment]))

        cur_tensors_list.append(cur_tensors)

        ###
        if label_map[label] in candidate_label_list:
            all_type_sentence[label_map[label]]=cur_tensors
            candidate_label_list.remove(label_map[label])

        if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
            #print("----")
            #print(sentiment_label_map[sentiment])
            #print("----")
            all_type_sentiment_sentence[sentiment_label_map[sentiment]]=cur_tensors
            candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])
        ###


    return all_type_sentiment_sentence, cur_tensors_list
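

# Data contract sketch (assumption: train.json is a list of records carrying
# "sentence", "aspect" and "sentiment" fields, as the commented lines above
# suggest). The first return value keeps one exemplar tensor-tuple per
# sentiment class; the second holds one 9-tuple per example, in the order
# unpacked in main(): (input_ids, input_ids_org, input_mask, segment_ids,
# lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label).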


def AugmentationData_Domain(bottom_k, top_k, tokenizer, max_seq_length):
    #top_k_shape = top_k.indices.shape
    #ids = top_k.indices.reshape(top_k_shape[0]*top_k_shape[1]).tolist()
    top_k_shape = top_k["indices"].shape
    ids_pos = top_k["indices"].reshape(top_k_shape[0]*top_k_shape[1]).tolist()
    #ids = top_k["indices"]

    bottom_k_shape = bottom_k["indices"].shape
    ids_neg = bottom_k["indices"].reshape(bottom_k_shape[0]*bottom_k_shape[1]).tolist()

    #print(ids_pos)
    #print(ids_neg)
    #exit()

    ids = ids_pos+ids_neg


    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_id_domain = list()

    for id, i in enumerate(ids):
        t1 = data[str(i)]['sentence']

        #tokens_a = tokenizer.tokenize(t1)
        tokens_a = tokenizer.tokenize(t1)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"
        '''

        # tokenize
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        all_input_ids.append(torch.tensor(cur_features.input_ids))
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
        all_input_mask.append(torch.tensor(cur_features.input_mask))
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
        all_is_next.append(torch.tensor(0))
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
        if i in ids_neg:
            all_id_domain.append(torch.tensor(0))
        elif i in ids_pos:
            all_id_domain.append(torch.tensor(1))


    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_id_domain))

    return cur_tensors
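

# Interface sketch (illustrative, never called; assumes the module-level
# `data` dict contains the referenced ids). top_k/bottom_k mimic torch.topk
# results as plain dicts whose "indices" entry is a (batch, K) LongTensor of
# out-of-domain document ids; retrieved docs get domain label 1 (top) or 0
# (bottom).
def _demo_AugmentationData_Domain(tokenizer, max_seq_length=128):
    top_k = {"indices": torch.tensor([[0, 1], [2, 3]])}     # hypothetical doc ids
    bottom_k = {"indices": torch.tensor([[4, 5], [6, 7]])}  # hypothetical doc ids
    batch = AugmentationData_Domain(bottom_k, top_k, tokenizer, max_seq_length)
    return batch  # 8-tuple of stacked tensors; the last entry is the domain label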


def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
    top_k_shape = top_k["indices"].shape
    sentence_ids = top_k["indices"]

    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()
    all_sentiment_labels = list()

    add_org = tuple(t.to('cpu') for t in add_org)
    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org

    ###
    #print("input_ids_",input_ids_.shape)
    #print("---")
    #print("sentence_ids",sentence_ids.shape)
    #print("---")
    #print("sentence_label_",sentence_label_.shape)
    #exit()


    for id_1, sent in enumerate(sentence_ids):
        for id_2, sent_id in enumerate(sent):

            t1 = data[str(int(sent_id))]['sentence']

            tokens_a = tokenizer.tokenize(t1)

            # tokenize (the original passed the builtin `id` as guid; guid is
            # informational only, so the inner-loop index is used instead)
            cur_example = InputExample(guid=id_2, tokens_a=tokens_a, tokens_b=None, is_next=0)

            # transform sample to features
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
            all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))

        all_input_ids.append(input_ids_[id_1])
        all_input_ids_org.append(input_ids_org_[id_1])
        all_input_mask.append(input_mask_[id_1])
        all_segment_ids.append(segment_ids_[id_1])
        all_lm_labels_ids.append(lm_label_ids_[id_1])
        all_is_next.append(is_next_[id_1])
        all_tail_idxs.append(tail_idxs_[id_1])
        all_sentence_labels.append(sentence_label_[id_1])
        all_sentiment_labels.append(sentiment_label_[id_1])


    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_sentence_labels),
                   torch.stack(all_sentiment_labels)
                   )


    return cur_tensors
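

# Behavior sketch: for each anchor example id_1, the K retrieved sentences in
# top_k["indices"][id_1] are tokenized and pseudo-labeled with the anchor's
# own sentence/sentiment labels, then the anchor itself is appended, so the
# returned batch holds batch_size * (K + 1) rows.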


def AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None, in_domain_rep=None):
    '''
    top_k_shape = top_k.indices.shape
    sentence_ids = top_k.indices
    '''
    #top_k_shape = top_k["indices"].shape
    #sentence_ids = top_k["indices"]


    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org


    #uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))

    all_sentence_binary_label = list()
    #all_in_task_rep_comb = list()
    all_in_rep_comb = list()

    for id_1, num in enumerate(sentence_label_):
        #print([sentence_label_==num])
        #print(type([sentence_label_==num]))
        sentence_label_int = (sentence_label_==num).to(torch.long)
        #print(sentence_label_int)
        #print(sentence_label_int.shape)
        #print(in_task_rep[id_1].shape)
        #print(in_task_rep.shape)
        #exit()
        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
        in_domain_rep_append = in_domain_rep[id_1].unsqueeze(0).expand(in_domain_rep.shape[0],-1)
        #print(in_task_rep_append)
        #print(in_task_rep_append.shape)
        in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
        in_domain_rep_comb = torch.cat((in_domain_rep_append,in_domain_rep),-1)
        #print(in_task_rep_comb)
        #print(in_task_rep_comb.shape)
        #exit()
        #sentence_label_int = sentence_label_int.to(torch.float32)
        #print(sentence_label_int)
        #exit()
        #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
        #all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
        #print(in_task_rep_comb.shape)
        #print(in_domain_rep_comb.shape)
        in_rep_comb = torch.cat([in_domain_rep_comb,in_task_rep_comb],-1)
        #print(in_rep.shape)
        #exit()
        all_sentence_binary_label.append(sentence_label_int)
        #all_in_task_rep_comb.append(in_task_rep_comb)
        all_in_rep_comb.append(in_rep_comb)
    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    #all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
    all_in_rep_comb = torch.stack(all_in_rep_comb)

    #cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
    cur_tensors = (all_in_rep_comb, all_sentence_binary_label)

    return cur_tensors
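

# Pairing sketch: for N in-batch examples with hidden size H, this DT variant
# concatenates [domain_i ; domain_j ; task_i ; task_j] for every ordered pair
# (i, j), giving (N, N, 4H) inputs that match the 768*4 domain_task binary
# classifier referenced in main(); binary label 1 marks same-class pairs.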


def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
    '''
    top_k_shape = top_k.indices.shape
    sentence_ids = top_k.indices
    '''
    #top_k_shape = top_k["indices"].shape
    #sentence_ids = top_k["indices"]


    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org


    #uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))

    all_sentence_binary_label = list()
    all_in_task_rep_comb = list()

    for id_1, num in enumerate(sentence_label_):
        #print([sentence_label_==num])
        #print(type([sentence_label_==num]))
        sentence_label_int = (sentence_label_==num).to(torch.long)
        #print(sentence_label_int)
        #print(sentence_label_int.shape)
        #print(in_task_rep[id_1].shape)
        #print(in_task_rep.shape)
        #exit()
        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
        #print(in_task_rep_append)
        #print(in_task_rep_append.shape)
        in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
        #print(in_task_rep_comb)
        #print(in_task_rep_comb.shape)
        #exit()
        #sentence_label_int = sentence_label_int.to(torch.float32)
        #print(sentence_label_int)
        #exit()
        #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
        #all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
        all_sentence_binary_label.append(sentence_label_int)
        all_in_task_rep_comb.append(in_task_rep_comb)
    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)

    cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)

    return cur_tensors
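

# Worked example (illustrative, never called): with labels [0, 1, 0] and
# random 768-dim task representations, the function returns (3, 3, 1536)
# pairwise concatenations and a (3, 3) binary matrix that is 1 where labels
# match.
def _demo_pos_neg_pairs():
    labels = torch.tensor([0, 1, 0])
    reps = torch.randn(3, 768)
    # only the sentence-label slot of add_org is read; pad the rest with None
    add_org = (None, None, None, None, None, None, None, labels, labels)
    combs, binary = AugmentationData_Task_pos_and_neg(add_org=add_org, in_task_rep=reps)
    assert combs.shape == (3, 3, 1536) and binary[0].tolist() == [1, 0, 1]
    return combs, binary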


class Dataset_noNext(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):

        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = [] # map sample index to doc and line

        # load samples into memory
        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        self.all_docs.append(doc)
                        doc = []
                        #remove last added sample because there won't be a subsequent line anymore in the doc
                        self.sample_to_doc.pop()
                    else:
                        #store as one sample
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1

            # if last row in file is not empty
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples later lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if doc does not end with empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        #t1, t2, is_next_label = self.random_sent(item)
        t1, is_next_label = self.random_sent(item)
        if is_next_label is None:
            is_next_label = 0


        # use the tokenizer held by the dataset (the original referenced a
        # module-level `tokenizer` that is only defined inside main())
        tokens_a = self.tokenizer.tokenize(t1)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"
        '''
        #tokens_b = self.tokenizer.tokenize(t2)

        # tokenize
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs))

        return cur_tensors

    def random_sent(self, index):
        """
        Next-sentence sampling is disabled in this variant: always return the
        sentence at `index` and None (no second sentence, no random negative).
        :param index: int, index of sample.
        :return: (str, None)
        """
        t1, t2 = self.get_corpus_line(index)
        return t1, None

    def get_corpus_line(self, item):
        """
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
        :param item: int, index of sample.
        :return: (str, str), two subsequent sentences from corpus
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            return t1, t2
            #return t1
        else:
            if self.line_buffer is None:
                # read first non-empty line of file
                while t1 == "" :
                    t1 = next(self.file).strip()
            else:
                # use t2 from previous iteration as new t1
                t1 = self.line_buffer
                # skip empty rows that are used for separating documents and keep track of current doc id
                while t1 == "":
                    t1 = next(self.file).strip()
                    self.current_doc = self.current_doc+1
            self.line_buffer = next(self.file).strip()

        assert t1 != ""
        return t1, t2


    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                #pick random line
                for _ in range(rand_index):
                    line = self.get_next_line()
            #check if our picked random line is really from another doc like we want it to be
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            #keep track of which document we are currently looking at to later avoid having the same doc as t1
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line
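

# Usage sketch (hypothetical "corpus.txt": one sentence per line, blank lines
# separating documents). Each item is the 7-tuple built in __getitem__ above.
def _demo_Dataset_noNext(tokenizer):
    ds = Dataset_noNext("corpus.txt", tokenizer, seq_len=128, on_memory=True)
    loader = DataLoader(ds, sampler=RandomSampler(ds), batch_size=8)
    return next(iter(loader))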


class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.
        Args:
            guid: Unique id for the example.
            tokens_a: list of str. The tokenized text of the first sequence. For
            single sequence tasks, only this sequence must be specified.
            tokens_b: (Optional) list of str. The tokenized text of the second
            sequence. Only must be specified for sequence pair tasks.
            is_next: (Optional) int. Next-sentence label (kept for
            compatibility; unused in this no-NSP variant).
            lm_labels: (Optional) masked-LM labels for the example.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        self.input_ids = input_ids
        self.input_ids_org = input_ids_org
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
        self.tail_idxs = tail_idxs


def random_word(tokens, tokenizer):
    """
    Masking some random tokens for the LM task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):

        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            prob /= 0.15

            # 80% randomly change token to mask token
            if prob < 0.8:
                #tokens[i] = "[MASK]"
                tokens[i] = "<mask>"

            # 10% randomly change token to random token
            elif prob < 0.9:
                # random.randint is inclusive on both ends, so the upper bound
                # must be vocab_size - 1 (the original could sample an
                # out-of-range id)
                candidate_id = random.randint(0, tokenizer.vocab_size - 1)
                tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)

            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                w = tokenizer.convert_tokens_to_ids(token)
                if w is not None:
                    output_label.append(w)
                else:
                    print("Token '{}' has no id in the vocab".format(token))
                    exit()
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                w = tokenizer.convert_tokens_to_ids("<unk>")
                output_label.append(w)
                logger.warning("Cannot find token '{}' in vocab. Using <unk> instead".format(token))
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label
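

# Illustrative check (assumes a tokenizer whose vocab contains "<mask>", as
# the masking above requires): labels stay -1 where nothing was masked and
# hold the original token id at masked positions, aligned 1:1 with the tokens.
def _demo_random_word(tokenizer):
    tokens = tokenizer.tokenize("the movie was great")
    masked_tokens, lm_labels = random_word(list(tokens), tokenizer)
    assert len(masked_tokens) == len(lm_labels)
    return masked_tokens, lm_labels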


def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    #now tokens_a is input_ids
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    #_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)

    #print(tokens_a)
    tokens_a_org = tokens_a.copy()
    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    #print("----")
    #print(tokens_a)
    #print(tokens_a_org)
    #exit()
    #print(t1_label)
    #exit()
    #tokens_b, t2_label = random_word(tokens_b, tokenizer)
    # concatenate lm labels and account for CLS, SEP, SEP
    #lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
    lm_label_ids = ([-1] + t1_label + [-1])

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0   0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    tokens_org = []
    segment_ids = []
    tokens.append("[CLS]")
    tokens_org.append("[CLS]")
    segment_ids.append(0)
    for i, token in enumerate(tokens_a):
        if token!="[SEP]":
            tokens.append(tokens_a[i])
            tokens_org.append(tokens_a_org[i])
            segment_ids.append(0)
        else:
            tokens.append("s")
            tokens_org.append("s")
            segment_ids.append(0)
    tokens.append("[SEP]")
    tokens_org.append("[SEP]")
    segment_ids.append(0)

    #tokens.append("[SEP]")
    #segment_ids.append(1)

    #input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = tokenizer.encode(tokens, add_special_tokens=False)
    input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
    tail_idxs = len(input_ids)-1

    #print(input_ids)
    input_ids = [w if w is not None else 0 for w in input_ids]
    input_ids_org = [w if w is not None else 0 for w in input_ids_org]
    #print(input_ids)
    #exit()

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    while len(input_ids) < max_seq_length:
        input_ids.append(pad_id)
        input_ids_org.append(pad_id)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    try:
        assert len(input_ids) == max_seq_length
        assert len(input_ids_org) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(lm_label_ids) == max_seq_length
    except AssertionError:  # a bare except would also swallow unrelated errors
        print("!!!Warning!!!")
        # 102 is the hard-coded [SEP] id of the bert-base vocab; truncate and
        # make sure the sequence still ends with a separator.
        input_ids = input_ids[:max_seq_length-1]
        if 102 not in input_ids:
            input_ids += [102]
        else:
            input_ids += [pad_id]
        input_ids_org = input_ids_org[:max_seq_length-1]
        if 102 not in input_ids_org:
            input_ids_org += [102]
        else:
            input_ids_org += [pad_id]
        input_mask = input_mask[:max_seq_length-1]+[0]
        segment_ids = segment_ids[:max_seq_length-1]+[0]
        lm_label_ids = lm_label_ids[:max_seq_length-1]+[-1]
    '''
    flag=False
    if len(input_ids) != max_seq_length:
        print(len(input_ids))
        flag=True
    if len(input_ids_org) != max_seq_length:
        print(len(input_ids_org))
        flag=True
    if len(input_mask) != max_seq_length:
        print(len(input_mask))
        flag=True
    if len(segment_ids) != max_seq_length:
        print(len(segment_ids))
        flag=True
    if len(lm_label_ids) != max_seq_length:
        print(len(lm_label_ids))
        flag=True
    if flag == True:
        print("1165")
        exit()
    '''

    '''
    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))
    '''

    features = InputFeatures(input_ids=input_ids,
                             input_ids_org = input_ids_org,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next,
                             tail_idxs=tail_idxs)
    return features
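

# End-to-end feature sketch (illustrative, never called): a single sentence
# goes through tokenize -> InputExample -> convert_example_to_features, and
# every sequence-level field comes back padded/truncated to max_seq_length.
def _demo_convert_example(tokenizer, max_seq_length=128):
    tokens_a = tokenizer.tokenize("service was slow but the food was good")
    example = InputExample(guid=0, tokens_a=tokens_a, tokens_b=None, is_next=0)
    features = convert_example_to_features(example, max_seq_length, tokenizer)
    assert len(features.input_ids) == max_seq_length
    return features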


def main():
    parser = argparse.ArgumentParser()

    parser = get_parameter(parser)

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    #tokenizer = BertTokenizer.from_pretrained(args.pretrain_model, do_lower_case=args.do_lower_case)
    tokenizer = BertTokenizer.from_pretrained(args.pretrain_model)


    #train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.data_dir_indomain)
        #train_dataset = Dataset_noNext(args.data_dir, tokenizer, seq_len=args.max_seq_length, corpus_lines=None, on_memory=args.on_memory)
        all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()



    # Prepare model
    model = BertForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
    #model = BertForSequenceClassification.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
    model.to(device)



    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        '''
        for par in param_optimizer:
            print(par[0])
        exit()
        '''
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)
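        # NOTE: warmup is hard-coded to 10% of the optimization steps here;
        # the --warmup_proportion flag parsed above is not consulted.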

        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
1148

1149
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
1150

1151

1152
        if n_gpu > 1:
1153
            model = torch.nn.DataParallel(model)
1154

1155
        if args.local_rank != -1:
1156
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
1157

1158

1159

1160
    global_step = 0
1161
    if args.do_train:
1162
        logger.info("***** Running training *****")
1163
        logger.info("  Num examples = %d", len(train_dataset))
1164
        logger.info("  Batch size = %d", args.train_batch_size)
1165
        logger.info("  Num steps = %d", num_train_optimization_steps)
1166

1167
        if args.local_rank == -1:
1168
            train_sampler = RandomSampler(train_dataset)
1169
            #all_type_sentence_sampler = RandomSampler(all_type_sentence)
1170
        else:
1171
            #TODO: check if this works with current data generator from disk that relies on next(file)
1172
            # (it doesn't return item back by index)
1173
            train_sampler = DistributedSampler(train_dataset)
1174
            #all_type_sentence_sampler = DistributedSampler(all_type_sentence)
1175
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
1176
        #all_type_sentence_dataloader = DataLoader(all_type_sentence, sampler=all_type_sentence_sampler, batch_size=len(all_type_sentence_label))
1177

1178
        output_loss_file = os.path.join(args.output_dir, "loss")
1179
        loss_fout = open(output_loss_file, 'w')
1180

1181

1182
        output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
1183
        loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')
1184
        model.train()
1185

1186

1187

1188

1189
        #alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
1190
        #alpha = float(1/args.num_train_epochs)
1191
        alpha = float(1)
1192
        #k=64
1193
        #choose_n=24
1194
        no_tune = -1
1195
        #k=16
1196
        #choose_n=16
1197
        k = args.K
1198
        choose_n = args.K
1199
        #k = 10
1200
        #k = 2
1201
        #retrive_gate = args.num_labels_task
1202
        #retrive_gate = len(train_dataset)/100
1203
        retrive_gate = 1
1204
        all_type_sentence_label = list()
1205
        all_previous_sentence_label = list()
1206
        all_type_sentiment_label = list()
1207
        all_previous_sentiment_label = list()
1208
        top_k_all_type = dict()
1209
        bottom_k_all_type = dict()
1210
        for epo in trange(int(args.num_train_epochs), desc="Epoch"):
1211
            tr_loss = 0
1212
            nb_tr_examples, nb_tr_steps = 0, 0
1213
            for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
1214

1215

1216
                #######################
1217
                ######################
1218
                ###Init 8 type sentence
1219
                ###Init 2 type sentiment
1220
                if (step == 0) and (epo == 0):
1221
                    #batch_ = tuple(t.to(device) for t in batch_)
1222
                    #all_type_sentence_ = tuple(t.to(device) for t in all_type_sentence)
1223

1224
                    input_ids_ = torch.stack([line[0] for line in all_type_sentence]).to(device)
1225
                    input_ids_org_ = torch.stack([line[1] for line in all_type_sentence]).to(device)
1226
                    input_mask_ = torch.stack([line[2] for line in all_type_sentence]).to(device)
1227
                    segment_ids_ = torch.stack([line[3] for line in all_type_sentence]).to(device)
1228
                    lm_label_ids_ = torch.stack([line[4] for line in all_type_sentence]).to(device)
1229
                    is_next_ = torch.stack([line[5] for line in all_type_sentence]).to(device)
1230
                    tail_idxs_ = torch.stack([line[6] for line in all_type_sentence]).to(device)
1231
                    sentence_label_ = torch.stack([line[7] for line in all_type_sentence]).to(device)
1232
                    sentiment_label_ = torch.stack([line[8] for line in all_type_sentence]).to(device)
1233

1234
                    with torch.no_grad():
1235

1236
                        #in_domain_rep_mean, in_task_rep_mean = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep_mean")
1237
                        in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1238
                        # Search id from Docs and ranking via (Domain/Task)
1239
                        #query_domain = in_domain_rep_mean.float().to("cpu")
1240
                        query_domain = in_domain_rep.float().to("cpu")
1241
                        query_domain = query_domain.unsqueeze(1)
1242
                        #query_task = in_task_rep_mean.float().to("cpu")
1243
                        query_task = in_task_rep.float().to("cpu")
1244
                        query_task = query_task.unsqueeze(1)
1245
                        #query_domain_task = torch.cat([query_domain,query_task],2)
1246

1247

1248
                        task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1249
                        task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1250
                        task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1251
                        task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1252

1253

1254
                        domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1255
                        domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1256
                        domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1257
                        domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768*2, 2)


                        #domain_task_binary_classifier_weight, domain_task_binary_classifier_bias = model(func="return_domain_task_binary_classifier")
                        #domain_task_binary_classifier_weight = domain_task_binary_classifier_weight[:int(domain_task_binary_classifier_weight.shape[0]/n_gpu)][:]
                        #domain_task_binary_classifier_bias = domain_task_binary_classifier_bias[:int(domain_task_binary_classifier_bias.shape[0]/n_gpu)][:]
                        #domain_task_binary_classifier = return_Classifier(domain_task_binary_classifier_weight, domain_task_binary_classifier_bias, 768*4, 2)

                        #start = time.time()
                        query_domain = query_domain.expand(-1, docs_tail.shape[0], -1)
                        query_task = query_task.expand(-1, docs_head.shape[0], -1)
                        #query_domain_task = query_domain_task.expand(-1, docs_head.shape[0], -1)

                        #################
                        #################
                        #Ranking

                        #LeakyReLU = torch.nn.LeakyReLU()
                        #Domain logit
                        '''
                        domain_binary_logit = LeakyReLU(domain_binary_classifier(docs_tail))
                        domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
                        domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
                        '''
                        domain_binary_logit = domain_binary_classifier(torch.cat([query_domain, docs_tail[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
                        target = torch.zeros(domain_binary_logit.shape[0], domain_binary_logit.shape[1], dtype=torch.long)
                        #domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
                        domain_binary_logit = ce_loss(domain_binary_logit.view(-1, 2), target.view(-1)).reshape(domain_binary_logit.shape[0], domain_binary_logit.shape[1])

                        #Task logit
                        task_binary_logit = task_binary_classifier(torch.cat([query_task, docs_head[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
                        #task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
                        #target = torch.zeros(task_binary_logit.shape[0], task_binary_logit.shape[1], dtype=torch.long)
                        task_binary_logit = ce_loss(task_binary_logit.view(-1, 2), target.view(-1)).reshape(task_binary_logit.shape[0], task_binary_logit.shape[1])

                        #Domain Task logit
                        domain_task_binary_logit = task_binary_logit + domain_binary_logit*0.5
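                        # Scoring convention, as this loop uses it: each document is scored by the
                        # cross-entropy of the pair classifier against the "negative" class 0, so
                        # a larger value means the (query, doc) pair looks more like a positive
                        # pair (top-k below takes largest=True). The combined score weights task
                        # relevance twice as much as domain relevance (task + 0.5 * domain).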

                        ### For paper
                        '''
                        ###Domain
                        ######
                        domain_top_k_all_type = torch.topk(domain_binary_logit, k, dim=1, largest=True, sorted=False)
                        perm = torch.randperm(domain_binary_logit.shape[1])
                        domain_bottom_k_all_type_indices = perm[:k]
                        domain_bottom_k_all_type_values = domain_binary_logit[:,domain_bottom_k_all_type_indices]
                        domain_bottom_k_all_type_indices = torch.stack(domain_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])


                        ####Task
                        task_top_k_all_type = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###Domain+Task
                        domain_task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        '''
                        ###
                        ###########################
                        ###Performance
                        ###Domain
                        ######
                        domain_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###
                        rand_seed = torch.randint(0, k, (choose_n,))
                        domain_top_k_all_type_indices = domain_top_k_all_type.indices[:,rand_seed]
                        domain_top_k_all_type_values = domain_top_k_all_type.values[:,rand_seed]
                        ###


                        #perm = torch.randperm(domain_task_binary_logit.shape[1])
                        #domain_bottom_k_all_type_indices = perm[:k]
                        #domain_bottom_k_all_type_values = domain_task_binary_logit[:,domain_bottom_k_all_type_indices]
                        #domain_bottom_k_all_type_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])

                        #domain_bottom_k_all_type = torch.topk(domain_task_binary_logit, k*2, dim=1, largest=False, sorted=False)
                        domain_bottom_k_all_type_indices = torch.randint(k+1, domain_binary_logit.shape[1], (choose_n*2,))
                        domain_bottom_k_all_type_values = domain_task_binary_logit[:,domain_bottom_k_all_type_indices]
                        domain_bottom_k_all_type_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])
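                        # Negative-sampling shortcut: indices are drawn uniformly from document
                        # positions above k instead of from a true bottom-k ranking. These are raw
                        # positions in the score matrix, not ranks, so overlap with the top-k set
                        # is unlikely but not strictly excluded.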


                        ####Task
                        task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###
                        rand_seed = torch.randint(0, k, (choose_n,))
                        task_top_k_all_type_indices = task_top_k_all_type.indices[:,rand_seed]
                        task_top_k_all_type_values = task_top_k_all_type.values[:,rand_seed]
                        ###


                        ###Domain+Task
                        domain_task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###
                        rand_seed = torch.randint(0, k, (choose_n,))
                        domain_task_top_k_all_type_indices = domain_task_top_k_all_type.indices[:,rand_seed]
                        domain_task_top_k_all_type_values = domain_task_top_k_all_type.values[:,rand_seed]
                        ###


                        ###########################


                        del domain_task_binary_logit, domain_binary_logit, task_binary_logit

                        all_type_sentiment_label = sentiment_label_.to('cpu')


                        domain_bottom_k_all_type = {"values":domain_bottom_k_all_type_values, "indices":domain_bottom_k_all_type_indices}
                        #domain_top_k_all_type = {"values":domain_top_k_all_type.values, "indices":domain_top_k_all_type.indices}
                        domain_top_k_all_type = {"values":domain_top_k_all_type_values, "indices":domain_top_k_all_type_indices}
                        #task_top_k_all_type = {"values":task_top_k_all_type.values, "indices":task_top_k_all_type.indices}
                        task_top_k_all_type = {"values":task_top_k_all_type_values, "indices":task_top_k_all_type_indices}
                        #domain_task_top_k_all_type = {"values":domain_task_top_k_all_type.values, "indices":domain_task_top_k_all_type.indices}
                        domain_task_top_k_all_type = {"values":domain_task_top_k_all_type_values, "indices":domain_task_top_k_all_type_indices}

                ######################
                ######################


                ###Normal mode
                batch_ = tuple(t.to(device) for t in batch_)
                input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_


                ###
                # Generate query representation
                in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")


                #if (step%10 == 0) or (sentence_label_.shape[0] != args.train_batch_size):
                if (step%retrive_gate == 0) or (sentiment_label_.shape[0] != args.train_batch_size):
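                    # Retrieval is refreshed only every `retrive_gate` steps (or when the batch
                    # is ragged); the other steps reuse the cached top-k results from the most
                    # recent refresh via the else branch below.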
1385

1386
                    with torch.no_grad():
1387
                        query_domain = in_domain_rep.float().to("cpu")
1388
                        query_domain = query_domain.unsqueeze(1)
1389
                        #query_task = in_task_rep_mean.float().to("cpu")
1390
                        query_task = in_task_rep.float().to("cpu")
1391
                        query_task = query_task.unsqueeze(1)
1392
                        query_domain_task = torch.cat([query_domain,query_task],2)
1393

1394

1395
                        task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1396
                        task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1397
                        task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1398
                        task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1399

1400

1401
                        domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1402
                        domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1403
                        domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1404
                        domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768*2, 2)
1405

1406

1407
                        #domain_task_binary_classifier_weight, domain_task_binary_classifier_bias = model(func="return_domain_task_binary_classifier")
1408
                        #domain_task_binary_classifier_weight = domain_task_binary_classifier_weight[:int(domain_task_binary_classifier_weight.shape[0]/n_gpu)][:]
1409
                        #domain_task_binary_classifier_bias = domain_task_binary_classifier_bias[:int(domain_task_binary_classifier_bias.shape[0]/n_gpu)][:]
1410
                        #domain_task_binary_classifier = return_Classifier(domain_task_binary_classifier_weight, domain_task_binary_classifier_bias, 768*4, 2)
1411

1412
                        #start = time.time()
1413
                        #query_domain = query_domain.expand(-1, docs.shape[0], -1)
1414
                        query_domain = query_domain.expand(-1, docs_tail.shape[0], -1)
1415
                        #query_task = query_task.expand(-1, docs.shape[0], -1)
1416
                        query_task = query_task.expand(-1, docs_head.shape[0], -1)
1417
                        #print(docs_head.shape)
1418
                        #print(query_domain_task.shape)
1419
                        #exit()
1420
                        #query_domain_task = query_domain_task.expand(-1, docs_head.shape[0], -1)
1421

1422
                        #################
1423
                        #################
1424
                        #Ranking
1425

1426
                        #LeakyReLU = torch.nn.LeakyReLU()
1427
                        #Domain logit
1428
                        '''
1429
                        domain_binary_logit = LeakyReLU(domain_binary_classifier(docs_tail))
1430
                        domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1431
                        domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
1432
                        '''
1433
                        domain_binary_logit = domain_binary_classifier(torch.cat([query_domain, docs_tail[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1434
                        target = torch.zeros(domain_binary_logit.shape[0], domain_binary_logit.shape[1], dtype=torch.long)
1435
                        #domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1436
                        domain_binary_logit = ce_loss(domain_binary_logit.view(-1, 2), target.view(-1)).reshape(domain_binary_logit.shape[0],domain_binary_logit.shape[1])
1437

1438
                        #Task logit
1439
                        task_binary_logit = task_binary_classifier(torch.cat([query_task, docs_head[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
1440
                        #task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
1441
                        task_binary_logit = ce_loss(task_binary_logit.view(-1, 2), target.view(-1)).reshape(task_binary_logit.shape[0],task_binary_logit.shape[1])
1442

1443
                        #Domain Task logit
1444
                        domain_task_binary_logit = task_binary_logit + domain_binary_logit*0.5
1445

1446
                        ####################
1447
                        ###paper
1448
                        ###Domaine
1449
                        ######
1450
                        #[batch_size, 36603]
1451
                        domain_top_k = torch.topk(domain_binary_logit, k, dim=1, largest=True, sorted=False)
1452
                        ###
1453
                        rand_seed = torch.randint(0,k,(choose_n,))
1454
                        domain_top_k_indices = domain_top_k.indices[:,rand_seed]
1455
                        domain_top_k_values = domain_top_k.values[:,rand_seed]
1456
                        ###
1457

1458
                        '''
1459
                        perm = torch.randperm(domain_binary_logit.shape[1])
1460
                        domain_bottom_k_indices = perm[:k]
1461
                        domain_bottom_k_values = domain_binary_logit[:,domain_bottom_k_indices]
1462
                        domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1463
                        '''
1464

1465
                        #domain_top_k = torch.topk(domain_binary_logit, k, dim=1, largest=False, sorted=False)
1466
                        domain_bottom_k_indices = torch.randint(k+1,domain_binary_logit.shape[1],(choose_n*2,))
1467
                        domain_bottom_k_values = domain_task_binary_logit[:,domain_bottom_k_indices]
1468
                        domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1469

1470

1471
                        task_top_k = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
1472
                        ###
1473
                        #rand_seed = torch.randint(0,k,(choose_n,))
1474
                        task_top_k_indices = task_top_k.indices[:,rand_seed]
1475
                        task_top_k_values = task_top_k.values[:,rand_seed]
1476
                        ###
1477

1478

1479
                        domain_task_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1480
                        #rand_seed = torch.randint(0,k,(choose_n,))
1481
                        domain_task_top_k_indices = domain_task_top_k.indices[:,rand_seed]
1482
                        domain_task_top_k_values = domain_task_top_k.values[:,rand_seed]
1483

1484

1485
                        ####################
1486
                        '''
1487
                        ###Performance
1488
                        domain_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1489
                        perm = torch.randperm(domain_task_binary_logit.shape[1])
1490
                        domain_bottom_k_indices = perm[:k]
1491
                        domain_bottom_k_values = domain_task_binary_logit[:,domain_bottom_k_indices]
1492
                        domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
1493
                        task_top_k = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
1494
                        domain_task_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
1495
                        '''
1496
                        ####################
1497

1498

1499
                        del domain_task_binary_logit, domain_binary_logit, task_binary_logit
1500

1501
                        all_previous_sentiment_label = sentiment_label_.to('cpu')
1502

1503
                        ######
1504

1505

1506
                        domain_bottom_k = {"values":domain_bottom_k_values, "indices":domain_bottom_k_indices}
1507
                        #domain_top_k = {"values":domain_top_k.values, "indices":domain_top_k.indices}
1508
                        domain_top_k = {"values":domain_top_k_values, "indices":domain_top_k_indices}
1509
                        #task_top_k = {"values":task_top_k.values, "indices":task_top_k.indices}
1510
                        task_top_k = {"values":task_top_k_values, "indices":task_top_k_indices}
1511
                        #domain_task_top_k = {"values":domain_task_top_k.values, "indices":domain_task_top_k.indices}
1512
                        domain_task_top_k = {"values":domain_task_top_k_values, "indices":domain_task_top_k_indices}
1513

1514

1515

1516

1517
                        domain_bottom_k_previous = {"values":torch.cat((domain_bottom_k["values"], domain_bottom_k_all_type["values"]),0), "indices":torch.cat((domain_bottom_k["indices"], domain_bottom_k_all_type["indices"]),0)}
1518
                        domain_top_k_previous = {"values":torch.cat((domain_top_k["values"], domain_top_k_all_type["values"]),0), "indices":torch.cat((domain_top_k["indices"], domain_top_k_all_type["indices"]),0)}
1519
                        task_top_k_previous = {"values":torch.cat((task_top_k["values"], task_top_k_all_type["values"]),0), "indices":torch.cat((task_top_k["indices"], task_top_k_all_type["indices"]),0)}
1520
                        domain_task_top_k_previous = {"values":torch.cat((domain_task_top_k["values"], domain_task_top_k_all_type["values"]),0), "indices":torch.cat((domain_task_top_k["indices"], domain_task_top_k_all_type["indices"]),0)}
1521

1522
                        all_previous_sentiment_label = torch.cat((all_previous_sentiment_label, all_type_sentiment_label))
1523
                else:
1524
                    ###Need to fix --> choice
1525
                    used_idx = torch.tensor([random.choice(((all_previous_sentiment_label==int(idx_)).nonzero()).tolist())[0] for idx_ in sentiment_label_])
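                    # Cache hit: for each example in the batch, reuse the retrieval of a cached
                    # query that carries the same sentiment label, picked at random (the
                    # "Need to fix" note above marks this selection as provisional).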
                    #top_k = {"values":top_k_previous["values"].index_select(0,used_idx), "indices":top_k_previous["indices"].index_select(0,used_idx)}
                    domain_top_k = {"values":domain_top_k_previous["values"].index_select(0,used_idx), "indices":domain_top_k_previous["indices"].index_select(0,used_idx)}
                    task_top_k = {"values":task_top_k_previous["values"].index_select(0,used_idx), "indices":task_top_k_previous["indices"].index_select(0,used_idx)}
                    domain_task_top_k = {"values":domain_task_top_k_previous["values"].index_select(0,used_idx), "indices":domain_task_top_k_previous["indices"].index_select(0,used_idx)}

                    #bottom_k = {"values":bottom_k_previous["values"].index_select(0,used_idx), "indices":bottom_k_previous["indices"].index_select(0,used_idx)}
                    domain_bottom_k = {"values":domain_bottom_k_previous["values"].index_select(0,used_idx), "indices":domain_bottom_k_previous["indices"].index_select(0,used_idx)}


                if epo < no_tune:

                    #################
                    #################
                    #Domain Binary Classifier - Outdomain
                    #batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
                    batch = AugmentationData_Domain(domain_top_k, domain_bottom_k, tokenizer, args.max_seq_length)
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, domain_id = batch

                    out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, lm_label=lm_label_ids, attention_mask=input_mask, func="in_domain_task_rep")
                    #print("======")
                    #print(domain_top_k["indices"].shape)
                    #print(input_ids_org.shape)
                    #print(out_domain_rep_tail.shape)
                    #print(in_domain_rep.shape)
                    #print("======")
                    ############Construct contrastive instances
                    comb_rep_pos = torch.cat([in_domain_rep, in_domain_rep.flip(0)], 1)
                    in_domain_rep_ready = in_domain_rep.repeat(1,int(out_domain_rep_tail.shape[0]/in_domain_rep.shape[0])).reshape(out_domain_rep_tail.shape[0],out_domain_rep_tail.shape[1])
                    comb_rep_unknow = torch.cat([in_domain_rep_ready, out_domain_rep_tail], 1)
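                    # Pair construction: comb_rep_pos pairs each in-domain example with another
                    # in-domain example (a reversed copy of the batch), giving known-positive
                    # pairs; comb_rep_unknow pairs each in-domain example with a retrieved
                    # document, whose domain label the binary classifier must decide.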

                    mix_domain_binary_loss, domain_binary_logit = model(func="domain_binary_classifier", in_domain_rep=comb_rep_pos.to(device), out_domain_rep=comb_rep_unknow.to(device), domain_id=domain_id, use_detach=False)
                    ############


                    #################
                    #################
                    ###Update_rep
                    indices = domain_top_k["indices"].reshape(domain_top_k["indices"].shape[0]*domain_top_k["indices"].shape[1])
                    indices_ = domain_bottom_k["indices"].reshape(domain_bottom_k["indices"].shape[0]*domain_bottom_k["indices"].shape[1])
                    indices = torch.cat([indices,indices_],0)
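                    # Memory-bank refresh: the freshly encoded head/tail representations are
                    # written back into docs_head/docs_tail at the retrieved positions, so the
                    # next retrieval step ranks against up-to-date document embeddings.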

                    out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
                    out_domain_rep_head.requires_grad=True

                    out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
                    out_domain_rep_tail.requires_grad=True


                    with torch.no_grad():
                        #Exam here!!!
                        try:
                            docs_head.index_copy_(0, indices, out_domain_rep_head)
                            docs_tail.index_copy_(0, indices, out_domain_rep_tail)
                        except Exception:
                            print("index_copy_ failed; dumping shapes:")
                            print("head", out_domain_rep_head.shape)
                            print("tail", out_domain_rep_tail.shape)
                            print("doc_h", docs_head.shape)
                            print("doc_t", docs_tail.shape)
                            print("ind", indices.shape)


                    #################
                    #################
                    #Task Binary Classifier - in domain
                    #Pseudo task --> won't backprop to the PLM: only trains the classifier [in-domain data]
                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
                    batch = tuple(t.to(device) for t in batch)
                    all_in_task_rep_comb, all_sentence_binary_label = batch
                    in_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)


                    #################
                    #################
                    #Train Task org - finetune
                    #split into: in_dom and query_ --> different weight
                    task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")


                    #################
                    #################
                    #Task Level - including outdomain
                    batch = AugmentationData_Task(task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
                    out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
                    ###
                    batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head)
                    batch = tuple(t.to(device) for t in batch)
                    all_in_task_rep_comb, all_sentence_binary_label = batch
                    out_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)
                    ###

                    #################
                    #################
                    ###Update_rep
                    indices = task_top_k["indices"].reshape(task_top_k["indices"].shape[0]*task_top_k["indices"].shape[1])

                    out_domain_rep_head = out_domain_rep_head[input_ids_org_.shape[0]:,:]
                    out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
                    out_domain_rep_head.requires_grad=True

                    out_domain_rep_tail = out_domain_rep_tail[input_ids_org_.shape[0]:,:]
                    out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
                    out_domain_rep_tail.requires_grad=True

                    with torch.no_grad():
                        try:
                            docs_head.index_copy_(0, indices, out_domain_rep_head)
                            docs_tail.index_copy_(0, indices, out_domain_rep_tail)
                        except Exception:
                            print("head", out_domain_rep_head.shape)
                            print("head", out_domain_rep_head.get_device())
                            print("tail", out_domain_rep_tail.shape)
                            print("tail", out_domain_rep_tail.get_device())
                            print("doc_h", docs_head.shape)
                            print("doc_h", docs_head.get_device())
                            print("doc_t", docs_tail.shape)
                            print("doc_t", docs_tail.get_device())
                            print("ind", indices.shape)
                            print("ind", indices.get_device())

                    ##############################
                    ##############################

                    #################
                    #################
                    #Domain-Task Level (Out-domain)
                    batch = AugmentationData_Task(domain_task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
                    out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
                    ###
                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head, in_domain_rep=out_domain_rep_tail)
                    batch = tuple(t.to(device) for t in batch)
                    all_in_task_rep_comb, all_sentence_binary_label = batch
                    out_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
                    ###


                    #Domain-Task Level (in-domain)
                    ###
                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
                    batch = tuple(t.to(device) for t in batch)
                    in_all_in_task_rep_comb, in_all_sentence_binary_label = batch
                    in_domain_task_binary_loss, in_domain_task_binary_logit = model(all_in_task_rep_comb=in_all_in_task_rep_comb, all_sentence_binary_label=in_all_sentence_binary_label, func="domain_task_binary_classifier")
                    ###


                    #################
                    #################
                    ###Update_rep
                    indices = domain_task_top_k["indices"].reshape(domain_task_top_k["indices"].shape[0]*domain_task_top_k["indices"].shape[1])

                    out_domain_rep_head = out_domain_rep_head[input_ids_org_.shape[0]:,:]
                    out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
                    out_domain_rep_head.requires_grad=True

                    out_domain_rep_tail = out_domain_rep_tail[input_ids_org_.shape[0]:,:]
                    out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
                    out_domain_rep_tail.requires_grad=True

                    with torch.no_grad():
                        try:
                            docs_head.index_copy_(0, indices, out_domain_rep_head)
                            docs_tail.index_copy_(0, indices, out_domain_rep_tail)
                        except Exception:
                            print("head", out_domain_rep_head.shape)
                            print("head", out_domain_rep_head.get_device())
                            print("tail", out_domain_rep_tail.shape)
                            print("tail", out_domain_rep_tail.get_device())
                            print("doc_h", docs_head.shape)
                            print("doc_h", docs_head.get_device())
                            print("doc_t", docs_tail.shape)
                            print("doc_t", docs_tail.get_device())
                            print("ind", indices.shape)
                            print("ind", indices.get_device())

                    ##############################
                    ##############################
                else:

                    #################
                    #Domain-Task Level (Out-domain)
                    batch = AugmentationData_Task(domain_task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
                    #out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
                    task_loss_out, class_logit_out = model(input_ids_org=input_ids_org, sentence_label=sentiment_label, attention_mask=input_mask, func="task_class")
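                    # After the warm-up epochs, retrieved neighbours feed the classification
                    # objective directly: AugmentationData_Task (defined earlier in the script)
                    # supplies sentiment labels for the retrieved texts, and task_loss_out trains
                    # on them alongside the original batch.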
                    ###
                    '''
                    #batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head, in_domain_rep=out_domain_rep_tail)
                    batch = tuple(t.to(device) for t in batch)
                    all_in_task_rep_comb, all_sentence_binary_label = batch
                    out_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
                    ###

                    #Domain-Task Level (in-domain)
                    ###
                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
                    batch = tuple(t.to(device) for t in batch)
                    in_all_in_task_rep_comb, in_all_sentence_binary_label = batch
                    in_domain_task_binary_loss, in_domain_task_binary_logit = model(all_in_task_rep_comb=in_all_in_task_rep_comb, all_sentence_binary_label=in_all_sentence_binary_label, func="domain_task_binary_classifier")
                    '''
                    ###

                    ###Finetune
                    task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")


                ############################################
                ############################################
                if epo < no_tune:
                    if n_gpu > 1:
                        # task_loss_out only exists in the epo >= no_tune branch, so the warm-up
                        # objective is built from the losses computed above (restoring the
                        # combination the original comment documented).
                        loss = mix_domain_binary_loss.mean()*0.5 + (in_task_binary_loss.mean() + out_task_binary_loss.mean())*0.5 + task_loss_org.mean() + out_domain_task_binary_loss.mean()
                    else:
                        # Single-GPU path: same objective without the DataParallel .mean() reduction.
                        loss = mix_domain_binary_loss + (in_task_binary_loss + out_task_binary_loss)/2 + task_loss_org + out_domain_task_binary_loss
                else:
                    if n_gpu > 1:
                        loss = task_loss_out.mean() + task_loss_org.mean()
                        #loss = task_loss_org.mean() + (in_domain_task_binary_loss.mean()+out_domain_task_binary_loss.mean())/2
                    else:
                        # Single-GPU path: assign the loss instead of only printing, so
                        # loss.backward() below does not hit an undefined name.
                        loss = task_loss_out + task_loss_org


                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
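                    # Standard gradient accumulation: each micro-batch loss is scaled down so
                    # the gradients summed over accumulation_steps match one full-batch update.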
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                ###
                loss_fout.write("{}\n".format(loss.item()))
                ###

                ###
                #loss_fout_no_pseudo.write("{}\n".format(loss.item()-pseudo.item()))
                ###

                tr_loss += loss.item()
                #nb_tr_examples += input_ids.size(0)
                nb_tr_examples += input_ids_.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # With mixed precision, clip on the amp master parameters.
                        #lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                    optimizer.step()
                    scheduler.step()
                    #optimizer.zero_grad()
                    model.zero_grad()
                    global_step += 1


            if epo < -1:
                continue
            else:
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                #output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
                output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(epo))
                torch.save(model_to_save.state_dict(), output_model_file)

        loss_fout.close()

        # Save a trained model
        logger.info("***** Saving fine-tuned model *****")
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        if args.do_train:
            torch.save(model_to_save.state_dict(), output_model_file)



def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length.

    Note: this variant only ever pops from tokens_a; tokens_b is left untouched.
    """

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        #total_length = len(tokens_a) + len(tokens_b)
        total_length = len(tokens_a)
        if total_length <= max_length:
            break
        else:
            tokens_a.pop()


def accuracy(out, labels):
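    """Count correct argmax predictions.

    E.g. accuracy(np.array([[0.2, 0.8], [0.9, 0.1]]), np.array([1, 0])) == 2.
    """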
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


if __name__ == "__main__":
    main()