CSS-LM

retrieve.py
1871 lines · 87.9 KB
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
import os
import random
from io import open
import json
import time
from torch.autograd import Variable
import torch.nn.functional as F

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from torch.nn import CrossEntropyLoss

from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
#from transformers.modeling_roberta import RobertaForMaskedLMDomainTask
from transformers.modeling_roberta_updateRep_self import RobertaForMaskedLMDomainTask
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)


ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
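# Added note: `reduction='none'` keeps one cross-entropy value per example
# instead of averaging over the batch. The training loop below exploits this
# to turn the frozen binary classifiers into retrieval scorers, roughly
# (sketch only, not the original code):
#   scores = ce_loss(logits.view(-1, 2), target.view(-1)).reshape(B, N)
# where the per-pair score ranks the N candidate documents for each of the
# B queries.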


def get_parameter(parser):

    ## Required parameters
    parser.add_argument("--data_dir_indomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus (in-domain).")
    parser.add_argument("--data_dir_outdomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus (out-of-domain).")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--augment_times",
                        default=None,
                        type=int,
                        required=True,
                        help="Default batch_size/augment_times to save model")
    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    ####
    parser.add_argument("--num_labels_task",
                        default=None, type=int,
                        required=True,
                        help="num_labels_task")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level',
                        type=str,
                        default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task",
                        default=0,
                        type=int,
                        required=True,
                        help="Choose Task")
    parser.add_argument("--K",
                        default=None,
                        type=int,
                        required=True,
                        help="K size")
    ####
    return parser
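# A hypothetical invocation, added for illustration; every value below is a
# placeholder, not taken from the original repo:
#
#   python retrieve.py \
#       --data_dir_indomain data/indomain/ \
#       --data_dir_outdomain data/opendomain/ \
#       --pretrain_model roberta-base \
#       --output_dir output/ \
#       --augment_times 1 --num_labels_task 2 --task 0 --K 16 --do_train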


def return_Classifier(weight, bias, dim_in, dim_out):
    #LeakyReLU = torch.nn.LeakyReLU
    classifier = torch.nn.Linear(dim_in, dim_out, bias=True)
    #print(classifier)
    #print(classifier.weight)
    #print(classifier.weight.shape)
    #print(classifier.weight.data)
    #print(classifier.weight.data.shape)
    #print("---")
    classifier.weight.data = weight.to("cpu")
    classifier.bias.data = bias.to("cpu")
    # Freeze the rebuilt classifier; the original `classifier.requires_grad = False`
    # only set a plain attribute on the Module and did not freeze its parameters.
    classifier.requires_grad_(False)
    #print(classifier)
    #print(classifier.weight)
    #print(classifier.weight.shape)
    #print("---")
    #exit()
    #print(classifier)
    #exit()
    #logit = LeakyReLU(classifier)
    return classifier
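# Added note: return_Classifier rebuilds, on CPU and frozen, a linear head
# from the weights/biases exported by the model's `return_*_binary_classifier`
# functions, so retrieved document representations can be scored under
# torch.no_grad() without touching the training graph.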


def load_GeneralDomain(dir_data_out):

    ###
    print("===========")
    print("Load CLS.pt and train.json")
    print("-----------")
    docs_head = torch.load(dir_data_out+"train_head.pt")
    docs_tail = torch.load(dir_data_out+"train_tail.pt")
    print("CLS.pt Done")
    print(docs_head.shape)
    print(docs_tail.shape)
    print("-----------")
    with open(dir_data_out+"train.json") as file:
        data = json.load(file)
    print("train.json Done")
    print("===========")
    docs_tail_head = torch.cat([docs_tail, docs_head],2)
    return docs_tail_head, docs_head, docs_tail, data
    ###
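# Added note (inferred from the calls above): the out-of-domain directory is
# expected to contain train_head.pt and train_tail.pt, tensors of shape
# [num_docs, num_layers, hidden] holding head (<s>) and tail (</s>) token
# representations, plus train.json mapping str(doc_id) -> {"sentence": ...}.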


parser = argparse.ArgumentParser()
parser = get_parameter(parser)
args = parser.parse_args()
#print(args.data_dir_outdomain)
#exit()

docs_tail_head, docs_head, docs_tail, data = load_GeneralDomain(args.data_dir_outdomain)
######
if docs_head.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
    #last <s>
    #docs = docs[:,0,:].unsqueeze(1)
    #mean 13 layers <s>
    docs_head = docs_head.mean(1).unsqueeze(1)
    print(docs_head.shape)
else:
    print(docs_head.shape)
if docs_tail.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
    #last <s>
    #docs = docs[:,0,:].unsqueeze(1)
    #mean 13 layers <s>
    docs_tail = docs_tail.mean(1).unsqueeze(1)
    print(docs_tail.shape)
else:
    print(docs_tail.shape)
######
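# Added note: when more than one layer per document is stored, the
# representations are mean-pooled over the layer dimension (the "mean 13
# layers" comments suggest the embedding layer plus 12 RoBERTa layers),
# leaving docs_head and docs_tail with shape [num_docs, 1, hidden].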

def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
    ###Open
    with open(data_dir_indomain+"train.json") as file:
        data = json.load(file)

    ###Preprocess
    num_label_list = list()
    label_sentence_dict = dict()
    num_sentiment_label_list = list()
    sentiment_label_dict = dict()
    for line in data:
        #line["sentence"]
        #line["aspect"]
        #line["sentiment"]
        num_sentiment_label_list.append(line["sentiment"])
        #num_label_list.append(line["aspect"])
        num_label_list.append(line["sentiment"])

    num_label = sorted(list(set(num_label_list)))
    label_map = {label : i for i , label in enumerate(num_label)}
    num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
    sentiment_label_map = {label : i for i , label in enumerate(num_sentiment_label)}
    print("=======")
    print("label_map:")
    print(label_map)
    print("=======")
    print("=======")
    print("sentiment_label_map:")
    print(sentiment_label_map)
    print("=======")

    ###Create data: 1 chosen sample along with the rest of the 7-class data

    '''
    all_input_ids = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()
    '''
    cur_tensors_list = list()
    #print(list(label_map.values()))
    candidate_label_list = list(label_map.values())
    candidate_sentiment_label_list = list(sentiment_label_map.values())
    all_type_sentence = [0]*len(candidate_label_list)
    all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
    for line in data:
        #line["sentence"]
        #line["aspect"]
        sentiment = line["sentiment"]
        sentence = line["sentence"]
        #label = line["aspect"]
        label = line["sentiment"]


        tokens_a = tokenizer.tokenize(sentence)
        #input_ids = tokenizer.encode(sentence, add_special_tokens=False)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"
        '''


        # tokenize (note: `id` below is the Python builtin; guid is unused downstream)
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
        # transform sample to features
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(0),
                       torch.tensor(cur_features.tail_idxs),
                       torch.tensor(label_map[label]),
                       torch.tensor(sentiment_label_map[sentiment]))

        cur_tensors_list.append(cur_tensors)

        ###
        if label_map[label] in candidate_label_list:
            all_type_sentence[label_map[label]]=cur_tensors
            candidate_label_list.remove(label_map[label])

        if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
            #print("----")
            #print(sentiment_label_map[sentiment])
            #print("----")
            all_type_sentiment_sentence[sentiment_label_map[sentiment]]=cur_tensors
            candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])
        ###




    return all_type_sentiment_sentence, cur_tensors_list
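# Added note: besides the full featurized training set (cur_tensors_list),
# this returns one exemplar tensor tuple per sentiment class
# (all_type_sentiment_sentence); the training loop uses those exemplars as
# fixed queries when retrieving documents for each class.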



def AugmentationData_Domain(bottom_k, top_k, tokenizer, max_seq_length):
    #top_k_shape = top_k.indices.shape
    #ids = top_k.indices.reshape(top_k_shape[0]*top_k_shape[1]).tolist()
    top_k_shape = top_k["indices"].shape
    ids_pos = top_k["indices"].reshape(top_k_shape[0]*top_k_shape[1]).tolist()
    #ids = top_k["indices"]

    bottom_k_shape = bottom_k["indices"].shape
    ids_neg = bottom_k["indices"].reshape(bottom_k_shape[0]*bottom_k_shape[1]).tolist()

    #print(ids_pos)
    #print(ids_neg)
    #exit()

    ids = ids_pos+ids_neg


    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_id_domain = list()

    for id, i in enumerate(ids):
        t1 = data[str(i)]['sentence']

        #tokens_a = tokenizer.tokenize(t1)
        tokens_a = tokenizer.tokenize(t1)
        '''
        if "</s>" in tokens_a:
            print("Have more than 1 </s>")
            #tokens_a[tokens_a.index("<s>")] = "s"
            for i in range(len(tokens_a)):
                if tokens_a[i] == "</s>":
                    tokens_a[i] = "s"
        '''

        # tokenize
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        all_input_ids.append(torch.tensor(cur_features.input_ids))
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
        all_input_mask.append(torch.tensor(cur_features.input_mask))
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
        all_is_next.append(torch.tensor(0))
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
        if i in ids_neg:
            all_id_domain.append(torch.tensor(0))
        elif i in ids_pos:
            all_id_domain.append(torch.tensor(1))


    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_id_domain))

    return cur_tensors
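# Added note: the batch built here carries pseudo domain labels in
# all_id_domain: 1 for top_k documents (retrieved as in-domain) and 0 for
# bottom_k documents (retrieved as out-of-domain), which is what the domain
# binary classifier is trained on.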


def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None, show_type=None):
    top_k_shape = top_k["indices"].shape
    sentence_ids = top_k["indices"]

    all_input_ids = list()
    all_input_ids_org = list()
    all_input_mask = list()
    all_segment_ids = list()
    all_lm_labels_ids = list()
    all_is_next = list()
    all_tail_idxs = list()
    all_sentence_labels = list()
    all_sentiment_labels = list()

    add_org = tuple(t.to('cpu') for t in add_org)
    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org

    ###
    #print("input_ids_",input_ids_.shape)
    #print("---")
    #print("sentence_ids",sentence_ids.shape)
    #print("---")
    #print("sentence_label_",sentence_label_.shape)
    #exit()


    for id_1, sent in enumerate(sentence_ids):
        #Domain-Task Level (in-domain)
        ###
        if int(id_1) in show_type and list(input_ids_org_[id_1]).index(2)>10:
            print("=======================")
            print("Instance:")
            print(tokenizer.decode(input_ids_org_[id_1]))
            print(int(sentence_label_[id_1]))
            print("-----------------------")
        for id_2, sent_id in enumerate(sent):
            t1 = data[str(int(sent_id))]['sentence']

            if int(id_1) in show_type and list(input_ids_org_[id_1]).index(2)>10:
                print("-----------------------")
                print("Retrieve:")
                print(t1)
                print("-----------------------")



            tokens_a = tokenizer.tokenize(t1)

            # tokenize (note: `id` below is the Python builtin; guid is unused downstream)
            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)

            # transform sample to features
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
            all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))

        all_input_ids.append(input_ids_[id_1])
        all_input_ids_org.append(input_ids_org_[id_1])
        all_input_mask.append(input_mask_[id_1])
        all_segment_ids.append(segment_ids_[id_1])
        all_lm_labels_ids.append(lm_label_ids_[id_1])
        all_is_next.append(is_next_[id_1])
        all_tail_idxs.append(tail_idxs_[id_1])
        all_sentence_labels.append(sentence_label_[id_1])
        all_sentiment_labels.append(sentiment_label_[id_1])


    cur_tensors = (torch.stack(all_input_ids),
                   torch.stack(all_input_ids_org),
                   torch.stack(all_input_mask),
                   torch.stack(all_segment_ids),
                   torch.stack(all_lm_labels_ids),
                   torch.stack(all_is_next),
                   torch.stack(all_tail_idxs),
                   torch.stack(all_sentence_labels),
                   torch.stack(all_sentiment_labels)
                   )


    return cur_tensors
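# Added note: every in-domain instance is concatenated with its top-k
# retrieved sentences, and each retrieved sentence inherits the instance's
# class and sentiment labels, i.e. retrieved text is pseudo-labeled for the
# task losses.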





def AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None, in_domain_rep=None):
    '''
    top_k_shape = top_k.indices.shape
    sentence_ids = top_k.indices
    '''
    #top_k_shape = top_k["indices"].shape
    #sentence_ids = top_k["indices"]


    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org


    #uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))

    all_sentence_binary_label = list()
    #all_in_task_rep_comb = list()
    all_in_rep_comb = list()

    for id_1, num in enumerate(sentence_label_):
        #print([sentence_label_==num])
        #print(type([sentence_label_==num]))
        sentence_label_int = (sentence_label_==num).to(torch.long)
        #print(sentence_label_int)
        #print(sentence_label_int.shape)
        #print(in_task_rep[id_1].shape)
        #print(in_task_rep.shape)
        #exit()
        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
        in_domain_rep_append = in_domain_rep[id_1].unsqueeze(0).expand(in_domain_rep.shape[0],-1)
        #print(in_task_rep_append)
        #print(in_task_rep_append.shape)
        in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
        in_domain_rep_comb = torch.cat((in_domain_rep_append,in_domain_rep),-1)
        #print(in_task_rep_comb)
        #print(in_task_rep_comb.shape)
        #exit()
        #sentence_label_int = sentence_label_int.to(torch.float32)
        #print(sentence_label_int)
        #exit()
        #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
        #all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
        #print(in_task_rep_comb.shape)
        #print(in_domain_rep_comb.shape)
        in_rep_comb = torch.cat([in_domain_rep_comb,in_task_rep_comb],-1)
        #print(in_rep.shape)
        #exit()
        all_sentence_binary_label.append(sentence_label_int)
        #all_in_task_rep_comb.append(in_task_rep_comb)
        all_in_rep_comb.append(in_rep_comb)
    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    #all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
    all_in_rep_comb = torch.stack(all_in_rep_comb)

    #cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
    cur_tensors = (all_in_rep_comb, all_sentence_binary_label)

    return cur_tensors
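# Added note: for a batch of B instances this builds all B x B ordered pairs
# [domain_i ; domain_j ; task_i ; task_j] (4*hidden wide) with binary label 1
# when instances i and j share a class; the width matches the 768*4-input
# domain_task classifier referenced (commented out) in main().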




def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
    '''
    top_k_shape = top_k.indices.shape
    sentence_ids = top_k.indices
    '''
    #top_k_shape = top_k["indices"].shape
    #sentence_ids = top_k["indices"]


    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org


    #uniqe_type_id = torch.LongTensor(list(set(sentence_label_.tolist())))

    all_sentence_binary_label = list()
    all_in_task_rep_comb = list()

    for id_1, num in enumerate(sentence_label_):
        #print([sentence_label_==num])
        #print(type([sentence_label_==num]))
        sentence_label_int = (sentence_label_==num).to(torch.long)
        #print(sentence_label_int)
        #print(sentence_label_int.shape)
        #print(in_task_rep[id_1].shape)
        #print(in_task_rep.shape)
        #exit()
        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
        #print(in_task_rep_append)
        #print(in_task_rep_append.shape)
        in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
        #print(in_task_rep_comb)
        #print(in_task_rep_comb.shape)
        #exit()
        #sentence_label_int = sentence_label_int.to(torch.float32)
        #print(sentence_label_int)
        #exit()
        #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
        #all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
        all_sentence_binary_label.append(sentence_label_int)
        all_in_task_rep_comb.append(in_task_rep_comb)
    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
    all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)

    cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)

    return cur_tensors
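# Added note: task-only variant of the pairwise builder above; pairs are
# [task_i ; task_j] (2*hidden wide), matching the 768*2-input task binary
# classifier used for ranking in main().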



'''
class Dataset_noNext(Dataset):
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):

        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = [] # map sample index to doc and line

        # load samples into memory
        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        self.all_docs.append(doc)
                        doc = []
                        #remove last added sample because there won't be a subsequent line anymore in the doc
                        self.sample_to_doc.pop()
                    else:
                        #store as one sample
                        sample = {"doc_id": len(self.all_docs),
                                  "line": len(doc)}
                        self.sample_to_doc.append(sample)
                        doc.append(line)
                        self.corpus_lines = self.corpus_lines + 1

            # if last row in file is not empty
            if self.all_docs[-1] != doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples later lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if doc does not end with empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        #t1, t2, is_next_label = self.random_sent(item)
        t1, is_next_label = self.random_sent(item)
        if is_next_label == None:
            is_next_label = 0


        #tokens_a = self.tokenizer.tokenize(t1)
        tokens_a = tokenizer.tokenize(t1)
        #if "</s>" in tokens_a:
        #    print("Have more than 1 </s>")
        #    #tokens_a[tokens_a.index("<s>")] = "s"
        #    for i in range(len(tokens_a)):
        #        if tokens_a[i] == "</s>":
        #            tokens_a[i] = "s"
        #tokens_b = self.tokenizer.tokenize(t2)

        # tokenize
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs))

        return cur_tensors

    def random_sent(self, index):
        """
        Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
        from one doc. With 50% the second sentence will be a random one from another doc.
        :param index: int, index of sample.
        :return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
        """
        t1, t2 = self.get_corpus_line(index)
        return t1, None

    def get_corpus_line(self, item):
        """
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
        :param item: int, index of sample.
        :return: (str, str), two subsequent sentences from corpus
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            return t1, t2
            #return t1
        else:
            if self.line_buffer is None:
                # read first non-empty line of file
                while t1 == "" :
                    t1 = next(self.file).strip()
            else:
                # use t2 from previous iteration as new t1
                t1 = self.line_buffer
                # skip empty rows that are used for separating documents and keep track of current doc id
                while t1 == "":
                    t1 = next(self.file).strip()
                    self.current_doc = self.current_doc+1
            self.line_buffer = next(self.file).strip()

        assert t1 != ""
        return t1, t2


    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                #pick random line
                for _ in range(rand_index):
                    line = self.get_next_line()
            #check if our picked random line is really from another doc like we want it to be
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            #keep track of which document we are currently looking at to later avoid having the same doc as t1
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line
'''


class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.
        Args:
            guid: Unique id for the example.
            tokens_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            tokens_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        self.input_ids = input_ids
        self.input_ids_org = input_ids_org
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
        self.tail_idxs = tail_idxs


def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []

    for i, token in enumerate(tokens):

        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            prob /= 0.15
            #candidate_id = random.randint(0,tokenizer.vocab_size)
            #print(tokenizer.convert_ids_to_tokens(candidate_id))


            # 80% randomly change token to mask token
            if prob < 0.8:
                #tokens[i] = "[MASK]"
                tokens[i] = "<mask>"

            # 10% randomly change token to random token
            elif prob < 0.9:
                #tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
                #tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
                # randint is inclusive on both ends, so cap at vocab_size - 1
                candidate_id = random.randint(0, tokenizer.vocab_size - 1)
                w = tokenizer.convert_ids_to_tokens(candidate_id)
                '''
                if tokens[i] == None:
                    candidate_id = 100
                    w = tokenizer.convert_ids_to_tokens(candidate_id)
                '''
                tokens[i] = w


            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                #output_label.append(tokenizer.vocab[token])
                w = tokenizer.convert_tokens_to_ids(token)
                if w is not None:
                    output_label.append(w)
                else:
                    print("Token has no id in the vocabulary")
                    exit()
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                #output_label.append(tokenizer.vocab["<unk>"])
                w = tokenizer.convert_tokens_to_ids("<unk>")
                output_label.append(w)
                logger.warning("Cannot find token '{}' in vocab. Using <unk> instead".format(token))
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label
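# Added worked example (hypothetical numbers): for a 20-token sentence, about
# 3 tokens are selected; each selected token becomes <mask> with probability
# 0.8, a random vocabulary token with probability 0.1, and is kept unchanged
# with probability 0.1, while output_label records its original id. Unselected
# positions get -1, which the LM loss ignores.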


def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing sentence input as strings and is_next label
    :param max_seq_length: int, maximum length of sequence.
    :param tokenizer: Tokenizer
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    #now tokens_a is input_ids
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    #_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)

    #print(tokens_a)
    tokens_a_org = tokens_a.copy()
    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    #print("----")
    #print(tokens_a)
    #print(tokens_a_org)
    #exit()
    #print(t1_label)
    #exit()
    #tokens_b, t2_label = random_word(tokens_b, tokenizer)
    # concatenate lm labels and account for CLS, SEP, SEP
    #lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
    lm_label_ids = ([-1] + t1_label + [-1])

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0   0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    tokens_org = []
    segment_ids = []
    tokens.append("<s>")
    tokens_org.append("<s>")
    segment_ids.append(0)
    for i, token in enumerate(tokens_a):
        if token!="</s>":
            tokens.append(tokens_a[i])
            tokens_org.append(tokens_a_org[i])
            segment_ids.append(0)
        else:
            tokens.append("s")
            tokens_org.append("s")
            segment_ids.append(0)
    tokens.append("</s>")
    tokens_org.append("</s>")
    segment_ids.append(0)

    #tokens.append("[SEP]")
    #segment_ids.append(1)

    #input_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = tokenizer.encode(tokens, add_special_tokens=False)
    input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
    tail_idxs = len(input_ids)-1

    #print(input_ids)
    input_ids = [w if w is not None else 0 for w in input_ids]
    input_ids_org = [w if w is not None else 0 for w in input_ids_org]
    #print(input_ids)
    #exit()

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    while len(input_ids) < max_seq_length:
        input_ids.append(pad_id)
        input_ids_org.append(pad_id)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    try:
        assert len(input_ids) == max_seq_length
        assert len(input_ids_org) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(lm_label_ids) == max_seq_length
    except AssertionError:
        # BPE re-encoding can make the sequence longer than max_seq_length;
        # truncate and make sure the </s> terminator (id 2) is present.
        print("!!!Warning!!!")
        input_ids = input_ids[:max_seq_length-1]
        if 2 not in input_ids:
            input_ids += [2]
        else:
            input_ids += [pad_id]
        input_ids_org = input_ids_org[:max_seq_length-1]
        if 2 not in input_ids_org:
            input_ids_org += [2]
        else:
            input_ids_org += [pad_id]
        input_mask = input_mask[:max_seq_length-1]+[0]
        segment_ids = segment_ids[:max_seq_length-1]+[0]
        lm_label_ids = lm_label_ids[:max_seq_length-1]+[-1]
    '''
    flag=False
    if len(input_ids) != max_seq_length:
        print(len(input_ids))
        flag=True
    if len(input_ids_org) != max_seq_length:
        print(len(input_ids_org))
        flag=True
    if len(input_mask) != max_seq_length:
        print(len(input_mask))
        flag=True
    if len(segment_ids) != max_seq_length:
        print(len(segment_ids))
        flag=True
    if len(lm_label_ids) != max_seq_length:
        print(len(lm_label_ids))
        flag=True
    if flag == True:
        print("1165")
        exit()
    '''

    '''
    if example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))
    '''

    features = InputFeatures(input_ids=input_ids,
                             input_ids_org = input_ids_org,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next,
                             tail_idxs=tail_idxs)
    return features
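# Added illustration (hypothetical values, not from a real run): with
# max_seq_length=8, a two-word input could come out as
#   tokens:        <s>  w1  <mask>  </s>  <pad>  <pad>  <pad>  <pad>
#   input_mask:      1   1       1     1      0      0      0      0
#   lm_label_ids:   -1  -1    <id>    -1     -1     -1     -1     -1
# with input_ids_org keeping the unmasked ids and tail_idxs pointing at </s>.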
1075

1076

1077
def main():
1078
    parser = argparse.ArgumentParser()
1079

1080
    parser = get_parameter(parser)
1081

1082
    args = parser.parse_args()
1083

1084
    if args.local_rank == -1 or args.no_cuda:
1085
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
1086
        n_gpu = torch.cuda.device_count()
1087
    else:
1088
        torch.cuda.set_device(args.local_rank)
1089
        device = torch.device("cuda", args.local_rank)
1090
        n_gpu = 1
1091
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
1092
        torch.distributed.init_process_group(backend='nccl')
1093
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
1094
        device, n_gpu, bool(args.local_rank != -1), args.fp16))
1095

1096
    if args.gradient_accumulation_steps < 1:
1097
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
1098
                            args.gradient_accumulation_steps))
1099

1100
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
1101

1102
    random.seed(args.seed)
1103
    np.random.seed(args.seed)
1104
    torch.manual_seed(args.seed)
1105
    if n_gpu > 0:
1106
        torch.cuda.manual_seed_all(args.seed)
1107

1108
    if not args.do_train:
1109
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")
1110

1111
    #if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
1112
    #    raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
1113
    #if not os.path.exists(args.output_dir):
1114
    #    os.makedirs(args.output_dir)
1115

1116
    #tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model, do_lower_case=args.do_lower_case)
1117
    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)
1118

1119

1120
    #train_examples = None
1121
    num_train_optimization_steps = None
1122
    if args.do_train:
1123
        print("Loading Train Dataset", args.data_dir_indomain)
1124
        #train_dataset = Dataset_noNext(args.data_dir, tokenizer, seq_len=args.max_seq_length, corpus_lines=None, on_memory=args.on_memory)
1125
        all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
1126
        num_train_optimization_steps = int(
1127
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
1128
        if args.local_rank != -1:
1129
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
1130

1131

1132

1133
    # Prepare model
1134
    model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1135
    #model = RobertaForSequenceClassification.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1136
    model.to(device)
1137

1138

1139

1140
    # Prepare optimizer
1141
    if args.do_train:
1142
        param_optimizer = list(model.named_parameters())
1143
        '''
1144
        for par in param_optimizer:
1145
            print(par[0])
1146
        exit()
1147
        '''
1148
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
1149
        optimizer_grouped_parameters = [
1150
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
1151
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
1152
            ]
1153
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
1154
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*0.1), num_training_steps=num_train_optimization_steps)
1155

1156
        if args.fp16:
1157
            try:
1158
                from apex import amp
1159
            except ImportError:
1160
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
1161
                exit()
1162

1163
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
1164

1165

1166
        if n_gpu > 1:
1167
            model = torch.nn.DataParallel(model)
1168

1169
        if args.local_rank != -1:
1170
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
1171

1172

1173

1174
    global_step = 0
1175
    if args.do_train:
1176
        logger.info("***** Running training *****")
1177
        logger.info("  Num examples = %d", len(train_dataset))
1178
        logger.info("  Batch size = %d", args.train_batch_size)
1179
        logger.info("  Num steps = %d", num_train_optimization_steps)
1180

1181
        if args.local_rank == -1:
1182
            train_sampler = RandomSampler(train_dataset)
1183
            #all_type_sentence_sampler = RandomSampler(all_type_sentence)
1184
        else:
1185
            #TODO: check if this works with current data generator from disk that relies on next(file)
1186
            # (it doesn't return item back by index)
1187
            train_sampler = DistributedSampler(train_dataset)
1188
            #all_type_sentence_sampler = DistributedSampler(all_type_sentence)
1189
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
1190
        #all_type_sentence_dataloader = DataLoader(all_type_sentence, sampler=all_type_sentence_sampler, batch_size=len(all_type_sentence_label))
1191

1192
        output_loss_file = os.path.join(args.output_dir, "loss")
1193
        loss_fout = open(output_loss_file, 'w')
1194

1195

1196
        output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
1197
        loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')
1198
        model.train()
1199

1200

1201

1202

1203
        #alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
1204
        #alpha = float(1/args.num_train_epochs)
1205
        alpha = float(1)
1206
        k=32
1207
        #k=64 #org --> use
1208
        #k=128
1209
        choose_n=32
1210
        no_tune = -1
1211
        #k=16
1212
        #k = args.K
1213
        #k = 10
1214
        #k = 2
1215
        #retrive_gate = args.num_labels_task
1216
        #retrive_gate = len(train_dataset)/100
1217
        retrive_gate = 1
1218
        all_type_sentence_label = list()
1219
        all_previous_sentence_label = list()
1220
        all_type_sentiment_label = list()
1221
        all_previous_sentiment_label = list()
1222
        top_k_all_type = dict()
1223
        bottom_k_all_type = dict()
1224

1225
        #show_type=[0,0,1,1,2,2] #3type
1226
        show_type=[0,0] #3type
1227
        for epo in trange(int(args.num_train_epochs), desc="Epoch"):
1228
            tr_loss = 0
1229
            nb_tr_examples, nb_tr_steps = 0, 0
1230
            for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
1231

1232

1233
                #######################
1234
                ######################
1235
                ###Init 8 type sentence
1236
                ###Init 2 type sentiment
1237
                if (step == 0) and (epo == 0):
1238
                    #batch_ = tuple(t.to(device) for t in batch_)
1239
                    #all_type_sentence_ = tuple(t.to(device) for t in all_type_sentence)
1240

1241
                    input_ids_ = torch.stack([line[0] for line in all_type_sentence]).to(device)
1242
                    input_ids_org_ = torch.stack([line[1] for line in all_type_sentence]).to(device)
1243
                    input_mask_ = torch.stack([line[2] for line in all_type_sentence]).to(device)
1244
                    segment_ids_ = torch.stack([line[3] for line in all_type_sentence]).to(device)
1245
                    lm_label_ids_ = torch.stack([line[4] for line in all_type_sentence]).to(device)
1246
                    is_next_ = torch.stack([line[5] for line in all_type_sentence]).to(device)
1247
                    tail_idxs_ = torch.stack([line[6] for line in all_type_sentence]).to(device)
1248
                    sentence_label_ = torch.stack([line[7] for line in all_type_sentence]).to(device)
1249
                    sentiment_label_ = torch.stack([line[8] for line in all_type_sentence]).to(device)
1250

1251
                    with torch.no_grad():
1252

1253
                        #in_domain_rep_mean, in_task_rep_mean = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep_mean")
1254
                        in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1255
                        # Search id from Docs and ranking via (Domain/Task)
1256
                        #query_domain = in_domain_rep_mean.float().to("cpu")
1257
                        query_domain = in_domain_rep.float().to("cpu")
1258
                        query_domain = query_domain.unsqueeze(1)
1259
                        #query_task = in_task_rep_mean.float().to("cpu")
1260
                        query_task = in_task_rep.float().to("cpu")
1261
                        query_task = query_task.unsqueeze(1)
1262
                        #query_domain_task = torch.cat([query_domain,query_task],2)
1263

1264

1265
                        task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
                        task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
                        task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
                        task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)


                        domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
                        domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
                        domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
                        domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768*2, 2)


                        #domain_task_binary_classifier_weight, domain_task_binary_classifier_bias = model(func="return_domain_task_binary_classifier")
                        #domain_task_binary_classifier_weight = domain_task_binary_classifier_weight[:int(domain_task_binary_classifier_weight.shape[0]/n_gpu)][:]
                        #domain_task_binary_classifier_bias = domain_task_binary_classifier_bias[:int(domain_task_binary_classifier_bias.shape[0]/n_gpu)][:]
                        #domain_task_binary_classifier = return_Classifier(domain_task_binary_classifier_weight, domain_task_binary_classifier_bias, 768*4, 2)

                        #start = time.time()
                        query_domain = query_domain.expand(-1, docs_tail.shape[0], -1)
                        query_task = query_task.expand(-1, docs_head.shape[0], -1)
                        #query_domain_task = query_domain_task.expand(-1, docs_head.shape[0], -1)

                        #################
                        #################
                        #Ranking

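                        # Editor note (hedged reading): each candidate is scored by the
                        # cross-entropy of the [query ; candidate] pair against a fixed
                        # target of class 0. A large loss w.r.t. class 0 means the
                        # classifier prefers class 1 (in-domain / task-relevant), so
                        # torch.topk(..., largest=True) below picks the most relevant
                        # documents.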
                        #LeakyReLU = torch.nn.LeakyReLU()
                        #Domain logit
                        '''
                        domain_binary_logit = LeakyReLU(domain_binary_classifier(docs_tail))
                        domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
                        domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
                        '''
                        domain_binary_logit = domain_binary_classifier(torch.cat([query_domain, docs_tail[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
                        target = torch.zeros(domain_binary_logit.shape[0], domain_binary_logit.shape[1], dtype=torch.long)
                        #domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
                        domain_binary_logit = ce_loss(domain_binary_logit.view(-1, 2), target.view(-1)).reshape(domain_binary_logit.shape[0],domain_binary_logit.shape[1])

                        #Task logit
                        task_binary_logit = task_binary_classifier(torch.cat([query_task, docs_head[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
                        #task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
                        #target = torch.zeros(task_binary_logit.shape[0], task_binary_logit.shape[1], dtype=torch.long)
                        task_binary_logit = ce_loss(task_binary_logit.view(-1, 2), target.view(-1)).reshape(task_binary_logit.shape[0],task_binary_logit.shape[1])

                        #Domain Task logit
                        #domain_task_binary_logit = task_binary_logit+domain_binary_logit*0.5
                        domain_task_binary_logit = task_binary_logit

                        ### For paper
                        '''
                        ###Domain
                        ######
                        domain_top_k_all_type = torch.topk(domain_binary_logit, k, dim=1, largest=True, sorted=False)
                        perm = torch.randperm(domain_binary_logit.shape[1])
                        domain_bottom_k_all_type_indices = perm[:k]
                        domain_bottom_k_all_type_values = domain_binary_logit[:,domain_bottom_k_all_type_indices]
                        domain_bottom_k_all_type_indices = torch.stack(domain_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])


                        ####Task
                        task_top_k_all_type = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###Domain+Task
                        domain_task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        '''
                        ###
                        ###########################
                        ###Performance
                        ###Domain
                        ######
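                        # Editor note (hedged): take the top-k highest-scoring documents,
                        # then randomly keep `choose_n` of them (via rand_seed) as the
                        # retrieved positives; the "bottom_k" negatives are drawn from
                        # random document indices in [k+1, N) rather than from the true
                        # lowest scores, i.e. they act as a cheap random contrast set.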
                        domain_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###
                        rand_seed = torch.randint(0,k,(choose_n,))
                        domain_top_k_all_type_indices = domain_top_k_all_type.indices[:,rand_seed]
                        domain_top_k_all_type_values = domain_top_k_all_type.values[:,rand_seed]
                        ###


                        #perm = torch.randperm(domain_task_binary_logit.shape[1])
                        #domain_bottom_k_all_type_indices = perm[:k]
                        #domain_bottom_k_all_type_values = domain_task_binary_logit[:,domain_bottom_k_all_type_indices]
                        #domain_bottom_k_all_type_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])

                        #domain_bottom_k_all_type = torch.topk(domain_task_binary_logit, k*2, dim=1, largest=False, sorted=False)
                        domain_bottom_k_all_type_indices = torch.randint(k+1,domain_binary_logit.shape[1],(choose_n*2,))
                        domain_bottom_k_all_type_values = domain_task_binary_logit[:,domain_bottom_k_all_type_indices]
                        domain_bottom_k_all_type_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_all_type_indices])

                        ####Task
                        task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###
                        rand_seed = torch.randint(0,k,(choose_n,))
                        task_top_k_all_type_indices = task_top_k_all_type.indices[:,rand_seed]
                        task_top_k_all_type_values = task_top_k_all_type.values[:,rand_seed]
                        ###


                        ###Domain+Task
                        domain_task_top_k_all_type = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###
                        rand_seed = torch.randint(0,k,(choose_n,))
                        domain_task_top_k_all_type_indices = domain_task_top_k_all_type.indices[:,rand_seed]
                        domain_task_top_k_all_type_values = domain_task_top_k_all_type.values[:,rand_seed]
                        ###


                        ###########################


                        del domain_task_binary_logit, domain_binary_logit, task_binary_logit

                        all_type_sentiment_label = sentiment_label_.to('cpu')


                        domain_bottom_k_all_type = {"values":domain_bottom_k_all_type_values, "indices":domain_bottom_k_all_type_indices}
                        #domain_top_k_all_type = {"values":domain_top_k_all_type.values, "indices":domain_top_k_all_type.indices}
                        domain_top_k_all_type = {"values":domain_top_k_all_type_values, "indices":domain_top_k_all_type_indices}
                        #task_top_k_all_type = {"values":task_top_k_all_type.values, "indices":task_top_k_all_type.indices}
                        task_top_k_all_type = {"values":task_top_k_all_type_values, "indices":task_top_k_all_type_indices}
                        #domain_task_top_k_all_type = {"values":domain_task_top_k_all_type.values, "indices":domain_task_top_k_all_type.indices}
                        domain_task_top_k_all_type = {"values":domain_task_top_k_all_type_values, "indices":domain_task_top_k_all_type_indices}

                ######################
                ######################
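                # Editor note (hedged): the {"values", "indices"} dicts built above mirror
                # the named tuple returned by torch.topk; they are the per-label seed
                # caches that the else-branch below falls back on whenever retrieval is
                # skipped for a step.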


                ###Normal mode
                batch_ = tuple(t.to(device) for t in batch_)
                input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_


                ###
                # Generate query representation
                in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
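                # Editor note (hedged): retrieval is refreshed every `retrieve_gate` steps
                # (and whenever the batch is smaller than train_batch_size, e.g. the last
                # partial batch); on all other steps the cached top-k rows from a previous
                # step with matching sentiment labels are reused in the else-branch below.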

                #if (step%10 == 0) or (sentence_label_.shape[0] != args.train_batch_size):
                if (step%retrieve_gate == 0) or (sentiment_label_.shape[0] != args.train_batch_size):

                    with torch.no_grad():
                        query_domain = in_domain_rep.float().to("cpu")
                        query_domain = query_domain.unsqueeze(1)
                        #query_task = in_task_rep_mean.float().to("cpu")
                        query_task = in_task_rep.float().to("cpu")
                        query_task = query_task.unsqueeze(1)
                        query_domain_task = torch.cat([query_domain,query_task],2)


                        task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
                        task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
                        task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
                        task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)


                        domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
                        domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
                        domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
                        domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768*2, 2)


                        #domain_task_binary_classifier_weight, domain_task_binary_classifier_bias = model(func="return_domain_task_binary_classifier")
                        #domain_task_binary_classifier_weight = domain_task_binary_classifier_weight[:int(domain_task_binary_classifier_weight.shape[0]/n_gpu)][:]
                        #domain_task_binary_classifier_bias = domain_task_binary_classifier_bias[:int(domain_task_binary_classifier_bias.shape[0]/n_gpu)][:]
                        #domain_task_binary_classifier = return_Classifier(domain_task_binary_classifier_weight, domain_task_binary_classifier_bias, 768*4, 2)

                        #start = time.time()
                        #query_domain = query_domain.expand(-1, docs.shape[0], -1)
                        query_domain = query_domain.expand(-1, docs_tail.shape[0], -1)
                        #query_task = query_task.expand(-1, docs.shape[0], -1)
                        query_task = query_task.expand(-1, docs_head.shape[0], -1)
                        #print(docs_head.shape)
                        #print(query_domain_task.shape)
                        #exit()
                        #query_domain_task = query_domain_task.expand(-1, docs_head.shape[0], -1)

                        #################
                        #################
                        #Ranking

                        #LeakyReLU = torch.nn.LeakyReLU()
                        #Domain logit
                        '''
                        domain_binary_logit = LeakyReLU(domain_binary_classifier(docs_tail))
                        domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
                        domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
                        '''
                        domain_binary_logit = domain_binary_classifier(torch.cat([query_domain, docs_tail[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
                        target = torch.zeros(domain_binary_logit.shape[0], domain_binary_logit.shape[1], dtype=torch.long)
                        #domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
                        domain_binary_logit = ce_loss(domain_binary_logit.view(-1, 2), target.view(-1)).reshape(domain_binary_logit.shape[0],domain_binary_logit.shape[1])

                        #Task logit
                        task_binary_logit = task_binary_classifier(torch.cat([query_task, docs_head[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2))
                        #task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
                        task_binary_logit = ce_loss(task_binary_logit.view(-1, 2), target.view(-1)).reshape(task_binary_logit.shape[0],task_binary_logit.shape[1])

                        #Domain Task logit
                        domain_task_binary_logit = task_binary_logit + domain_binary_logit*0.5

                        ####################
                        ###paper
                        ###Domain
                        ######
                        #[batch_size, 36603]
                        domain_top_k = torch.topk(domain_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###
                        rand_seed = torch.randint(0,k,(choose_n,))
                        domain_top_k_indices = domain_top_k.indices[:,rand_seed]
                        domain_top_k_values = domain_top_k.values[:,rand_seed]
                        ###


                        '''
                        perm = torch.randperm(domain_binary_logit.shape[1])
                        domain_bottom_k_indices = perm[:k]
                        domain_bottom_k_values = domain_binary_logit[:,domain_bottom_k_indices]
                        domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
                        '''

                        #domain_top_k = torch.topk(domain_binary_logit, k, dim=1, largest=False, sorted=False)
                        domain_bottom_k_indices = torch.randint(k+1,domain_binary_logit.shape[1],(choose_n*2,))
                        domain_bottom_k_values = domain_task_binary_logit[:,domain_bottom_k_indices]
                        domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])


                        task_top_k = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
                        ###
                        #rand_seed = torch.randint(0,k,(choose_n,))
                        task_top_k_indices = task_top_k.indices[:,rand_seed]
                        task_top_k_values = task_top_k.values[:,rand_seed]
                        ###


                        domain_task_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        #rand_seed = torch.randint(0,k,(choose_n,))
                        domain_task_top_k_indices = domain_task_top_k.indices[:,rand_seed]
                        domain_task_top_k_values = domain_task_top_k.values[:,rand_seed]


                        ####################
                        '''
                        ###Performance
                        domain_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        perm = torch.randperm(domain_task_binary_logit.shape[1])
                        domain_bottom_k_indices = perm[:k]
                        domain_bottom_k_values = domain_task_binary_logit[:,domain_bottom_k_indices]
                        domain_bottom_k_indices = torch.stack(domain_task_binary_logit.shape[0]*[domain_bottom_k_indices])
                        task_top_k = torch.topk(task_binary_logit, k, dim=1, largest=True, sorted=False)
                        domain_task_top_k = torch.topk(domain_task_binary_logit, k, dim=1, largest=True, sorted=False)
                        '''
                        ####################


                        del domain_task_binary_logit, domain_binary_logit, task_binary_logit

                        all_previous_sentiment_label = sentiment_label_.to('cpu')

                        ######

                        domain_bottom_k = {"values":domain_bottom_k_values, "indices":domain_bottom_k_indices}
1525
                        #domain_top_k = {"values":domain_top_k.values, "indices":domain_top_k.indices}
1526
                        domain_top_k = {"values":domain_top_k_values, "indices":domain_top_k_indices}
1527
                        #task_top_k = {"values":task_top_k.values, "indices":task_top_k.indices}
1528
                        task_top_k = {"values":task_top_k_values, "indices":task_top_k_indices}
1529
                        #domain_task_top_k = {"values":domain_task_top_k.values, "indices":domain_task_top_k.indices}
1530
                        domain_task_top_k = {"values":domain_task_top_k_values, "indices":domain_task_top_k_indices}
1531

1532

1533

1534

1535
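                        # Editor note (hedged): concatenate the label-exemplar results from
                        # the init step onto the fresh batch results, so the "_previous"
                        # caches always contain at least one row for every sentiment label.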
                        domain_bottom_k_previous = {"values":torch.cat((domain_bottom_k["values"], domain_bottom_k_all_type["values"]),0), "indices":torch.cat((domain_bottom_k["indices"], domain_bottom_k_all_type["indices"]),0)}
                        domain_top_k_previous = {"values":torch.cat((domain_top_k["values"], domain_top_k_all_type["values"]),0), "indices":torch.cat((domain_top_k["indices"], domain_top_k_all_type["indices"]),0)}
                        task_top_k_previous = {"values":torch.cat((task_top_k["values"], task_top_k_all_type["values"]),0), "indices":torch.cat((task_top_k["indices"], task_top_k_all_type["indices"]),0)}
                        domain_task_top_k_previous = {"values":torch.cat((domain_task_top_k["values"], domain_task_top_k_all_type["values"]),0), "indices":torch.cat((domain_task_top_k["indices"], domain_task_top_k_all_type["indices"]),0)}

                        all_previous_sentiment_label = torch.cat((all_previous_sentiment_label, all_type_sentiment_label))
                else:
                    ###Need to fix --> choice
                    used_idx = torch.tensor([random.choice(((all_previous_sentiment_label==int(idx_)).nonzero()).tolist())[0] for idx_ in sentiment_label_])
                    #top_k = {"values":top_k_previous["values"].index_select(0,used_idx), "indices":top_k_previous["indices"].index_select(0,used_idx)}
                    domain_top_k = {"values":domain_top_k_previous["values"].index_select(0,used_idx), "indices":domain_top_k_previous["indices"].index_select(0,used_idx)}
                    task_top_k = {"values":task_top_k_previous["values"].index_select(0,used_idx), "indices":task_top_k_previous["indices"].index_select(0,used_idx)}
                    domain_task_top_k = {"values":domain_task_top_k_previous["values"].index_select(0,used_idx), "indices":domain_task_top_k_previous["indices"].index_select(0,used_idx)}

                    #bottom_k = {"values":bottom_k_previous["values"].index_select(0,used_idx), "indices":bottom_k_previous["indices"].index_select(0,used_idx)}
                    domain_bottom_k = {"values":domain_bottom_k_previous["values"].index_select(0,used_idx), "indices":domain_bottom_k_previous["indices"].index_select(0,used_idx)}



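                # Editor note (hedged): two-phase schedule. While epo < no_tune, the
                # retrieval classifiers (domain / task / domain-task) are trained jointly
                # with the task head and the memory banks are refreshed; afterwards only
                # the domain-task classifier and the supervised task head are updated.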
                if epo < no_tune:

                    #################
                    #################
                    #Domain Binary Classifier - Outdomain
                    #batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
                    batch = AugmentationData_Domain(domain_top_k, domain_bottom_k, tokenizer, args.max_seq_length)
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, domain_id = batch

                    out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, lm_label=lm_label_ids, attention_mask=input_mask, func="in_domain_task_rep")
                    #print("======")
                    #print(domain_top_k["indices"].shape)
                    #print(input_ids_org.shape)
                    #print(out_domain_rep_tail.shape)
                    #print(in_domain_rep.shape)
                    #print("======")
                    ############Construct contrastive instances
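                    # Editor note (hedged): positive pairs concatenate each in-domain rep
                    # with another in-domain rep from the same batch (the batch flipped),
                    # while "unknown" pairs tile the in-domain reps to match the number of
                    # retrieved documents and pair them with the retrieved tail reps; the
                    # domain binary classifier is trained to tell the two apart.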
                    comb_rep_pos = torch.cat([in_domain_rep,in_domain_rep.flip(0)], 1)
                    in_domain_rep_ready = in_domain_rep.repeat(1,int(out_domain_rep_tail.shape[0]/in_domain_rep.shape[0])).reshape(out_domain_rep_tail.shape[0],out_domain_rep_tail.shape[1])
                    comb_rep_unknown = torch.cat([in_domain_rep_ready, out_domain_rep_tail], 1)

                    mix_domain_binary_loss, domain_binary_logit = model(func="domain_binary_classifier", in_domain_rep=comb_rep_pos.to(device), out_domain_rep=comb_rep_unknown.to(device), domain_id=domain_id, use_detach=False)
                    ############


                    #################
                    #################
                    ###Update_rep
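                    # Editor note (hedged): write the freshly computed representations of
                    # the retrieved documents back into the CPU-side memory banks
                    # (docs_head / docs_tail) at their original positions via index_copy_,
                    # so later retrieval steps rank against up-to-date encodings.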
                    indices = domain_top_k["indices"].reshape(domain_top_k["indices"].shape[0]*domain_top_k["indices"].shape[1])
1586
                    indices_ = domain_bottom_k["indices"].reshape(domain_bottom_k["indices"].shape[0]*domain_bottom_k["indices"].shape[1])
1587
                    indices = torch.cat([indices,indices_],0)
1588

1589
                    out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
1590
                    out_domain_rep_head.requires_grad=True
1591

1592
                    out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
1593
                    out_domain_rep_tail.requires_grad=True
1594

1595

1596
                    with torch.no_grad():
1597
                        #Exam here!!!
1598
                        try:
1599
                            docs_head.index_copy_(0, indices, out_domain_rep_head)
1600
                            docs_tail.index_copy_(0, indices, out_domain_rep_tail)
1601
                        except:
1602
                            print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
1603
                            print("head",out_domain_rep_head.shape)
1604
                            print("tail",out_domain_rep_head.shape)
1605
                            print("doc_h",docs_head.shape)
1606
                            print("doc_t",docs_tail.shape)
1607
                            print("ind",indices.shape)
1608

1609

1610

1611
                    #################
                    #################
                    #Task Binary Classifier - in domain
                    #Pseudo task --> won't backprop into the PLM; only trains the classifier head [in-domain data]
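                    # Editor note (hedged): AugmentationData_Task_pos_and_neg_DT appears to
                    # build positive/negative pairs of [task ; domain] representations from
                    # the current in-domain batch (top_k=None), on which the task binary
                    # classifier is trained.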
                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)

                    batch = tuple(t.to(device) for t in batch)
                    all_in_task_rep_comb, all_sentence_binary_label = batch
                    in_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)


                    #################
                    #################
                    #Train Task org - finetune
                    #split into: in_dom and query_  --> different weight
                    task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")


                    #################
                    #################
                    #Task Level   including outdomain
                    batch = AugmentationData_Task(task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
                    out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
                    ###
                    batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head)
                    batch = tuple(t.to(device) for t in batch)
                    all_in_task_rep_comb, all_sentence_binary_label = batch
                    out_task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier", use_detach=False)
                    ###

                    #################
                    #################
                    ###Update_rep
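                    # Editor note (hedged): AugmentationData_Task prepends the original
                    # batch (add_org), so the first input_ids_org_.shape[0] rows belong to
                    # it; they are sliced off below so that only the retrieved documents'
                    # reps are written back into the memory banks.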
                    indices = task_top_k["indices"].reshape(task_top_k["indices"].shape[0]*task_top_k["indices"].shape[1])
1647

1648
                    out_domain_rep_head = out_domain_rep_head[input_ids_org_.shape[0]:,:]
1649
                    out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
1650
                    out_domain_rep_head.requires_grad=True
1651

1652
                    out_domain_rep_tail = out_domain_rep_tail[input_ids_org_.shape[0]:,:]
1653
                    out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
1654
                    out_domain_rep_tail.requires_grad=True
1655

1656
                    with torch.no_grad():
1657
                        try:
1658
                            docs_head.index_copy_(0, indices, out_domain_rep_head)
1659
                            docs_tail.index_copy_(0, indices, out_domain_rep_tail)
1660
                        except:
1661
                            print("head",out_domain_rep_head.shape)
1662
                            print("head",out_domain_rep_head.get_device())
1663
                            print("tail",out_domain_rep_head.shape)
1664
                            print("tail",out_domain_rep_head.get_device())
1665
                            print("doc_h",docs_head.shape)
1666
                            print("doc_h",docs_head.get_device())
1667
                            print("doc_t",docs_tail.shape)
1668
                            print("doc_t",docs_tail.get_device())
1669
                            print("ind",indices.shape)
1670
                            print("ind",indices.get_device())
1671


                    ##############################
                    ##############################

                    #################
                    #################
                    #Domain-Task Level (Out-domain)
                    batch = AugmentationData_Task(domain_task_top_k, tokenizer, args.max_seq_length, add_org=batch_)
                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
                    out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
                    ###
                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head, in_domain_rep=out_domain_rep_tail)
                    batch = tuple(t.to(device) for t in batch)
                    all_in_task_rep_comb, all_sentence_binary_label = batch
                    out_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
                    ###


                    #Domain-Task Level (in-domain)
                    ###
                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
                    batch = tuple(t.to(device) for t in batch)
                    in_all_in_task_rep_comb, in_all_sentence_binary_label = batch
                    in_domain_task_binary_loss, in_domain_task_binary_logit = model(all_in_task_rep_comb=in_all_in_task_rep_comb, all_sentence_binary_label=in_all_sentence_binary_label, func="domain_task_binary_classifier")
                    ###


                    #################
                    #################
                    ###Update_rep
                    indices = domain_task_top_k["indices"].reshape(domain_task_top_k["indices"].shape[0]*domain_task_top_k["indices"].shape[1])

                    out_domain_rep_head = out_domain_rep_head[input_ids_org_.shape[0]:,:]
                    out_domain_rep_head = out_domain_rep_head.reshape(out_domain_rep_head.shape[0],1,out_domain_rep_head.shape[1]).to("cpu").data
                    out_domain_rep_head.requires_grad=True

                    out_domain_rep_tail = out_domain_rep_tail[input_ids_org_.shape[0]:,:]
                    out_domain_rep_tail = out_domain_rep_tail.reshape(out_domain_rep_tail.shape[0],1,out_domain_rep_tail.shape[1]).to("cpu").data
                    out_domain_rep_tail.requires_grad=True

                    with torch.no_grad():
                        try:
                            docs_head.index_copy_(0, indices, out_domain_rep_head)
                            docs_tail.index_copy_(0, indices, out_domain_rep_tail)
                        except RuntimeError:
                            print("head",out_domain_rep_head.shape)
                            print("head",out_domain_rep_head.get_device())
                            print("tail",out_domain_rep_tail.shape)
                            print("tail",out_domain_rep_tail.get_device())
                            print("doc_h",docs_head.shape)
                            print("doc_h",docs_head.get_device())
                            print("doc_t",docs_tail.shape)
                            print("doc_t",docs_tail.get_device())
                            print("ind",indices.shape)
                            print("ind",indices.get_device())


                    ##############################
                    ##############################
                else:

                    #################
                    #Domain-Task Level (Out-domain)
                    batch = AugmentationData_Task(domain_task_top_k, tokenizer, args.max_seq_length, add_org=batch_, show_type=show_type)

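                    # Editor note (hedged): the loop below looks like a sample-inspection
                    # hook for the paper's qualitative examples: for each label still in
                    # show_type it waits until a sufficiently long example appears (token
                    # id 2, RoBERTa's </s>, beyond position 20), removes that label, and
                    # calls exit() once all labels have been seen.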
                    _, input_ids_org11, _, _, _, _, _, _, sentiment_label11 = batch_
                    for indexxx, label in enumerate(sentiment_label11):
                        if int(label) in show_type and list(input_ids_org11[indexxx]).index(2)>20:
                            #print(int(label))
                            #show_type.pop(show_type(int(label)))
                            if list(input_ids_org11[indexxx]).index(2)>10:
                                show_type.remove(int(label))
                            else:
                                pass
                            if len(show_type)==0:
                                exit()



                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
                    out_domain_rep_tail, out_domain_rep_head = model(input_ids_org=input_ids_org, tail_idxs=tail_idxs, attention_mask=input_mask, func="in_domain_task_rep")
                    ###
                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch, in_task_rep=out_domain_rep_head, in_domain_rep=out_domain_rep_tail)
                    batch = tuple(t.to(device) for t in batch)
                    all_in_task_rep_comb, all_sentence_binary_label = batch
                    out_domain_task_binary_loss, domain_task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="domain_task_binary_classifier")
                    ###

                    batch = AugmentationData_Task_pos_and_neg_DT(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep, in_domain_rep=in_domain_rep)
                    batch = tuple(t.to(device) for t in batch)
                    in_all_in_task_rep_comb, in_all_sentence_binary_label = batch
                    in_domain_task_binary_loss, in_domain_task_binary_logit = model(all_in_task_rep_comb=in_all_in_task_rep_comb, all_sentence_binary_label=in_all_sentence_binary_label, func="domain_task_binary_classifier")
                    ###

                    ###Finetune
                    task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")



                ############################################
                ############################################
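                # Editor note (hedged): total loss. Phase 1 (epo < no_tune) mixes the
                # domain binary loss (x0.5), the averaged in/out task binary losses
                # (x0.5), the supervised task loss, and the out-domain domain-task loss;
                # phase 2 keeps the supervised task loss plus the averaged domain-task
                # losses. The single-GPU branch is unimplemented and only prints a warning.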
                if epo < no_tune:
                    if n_gpu > 1:
                        loss = mix_domain_binary_loss.mean()*0.5 + (in_task_binary_loss.mean() + out_task_binary_loss.mean())*0.5 + task_loss_org.mean() + out_domain_task_binary_loss
                    else:
                        #loss = mix_domain_binary_loss + (in_task_binary_loss + out_task_binary_loss)/2 + task_loss_org + out_domain_task_binary_loss
                        print("Not using multiple GPUs: single-GPU loss path not implemented")
                else:
                    if n_gpu > 1:
                        loss = task_loss_org.mean() + (in_domain_task_binary_loss.mean()+out_domain_task_binary_loss.mean())/2
                    else:
                        #loss = mix_domain_binary_loss + (in_task_binary_loss + out_task_binary_loss)/2 + task_loss_org + out_domain_task_binary_loss
                        print("Not using multiple GPUs: single-GPU loss path not implemented")


                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                ###
                loss_fout.write("{}\n".format(loss.item()))
                ###

                ###
                #loss_fout_no_pseudo.write("{}\n".format(loss.item()-pseudo.item()))
                ###

                tr_loss += loss.item()
                #nb_tr_examples += input_ids.size(0)
                nb_tr_examples += input_ids_.size(0)
                nb_tr_steps += 1
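                # Editor note (hedged): standard accumulation boundary - clip gradients
                # (via amp's master params when fp16 is on), step the optimizer and LR
                # scheduler, then zero gradients on the model rather than the optimizer.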
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        #lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    ###
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    ###

                    optimizer.step()
                    ###
                    scheduler.step()
                    ###
                    #optimizer.zero_grad()
                    model.zero_grad()
                    global_step += 1


            if epo < -1:
                continue
            else:
                #model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
                #output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
                #output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(epo))
                #torch.save(model_to_save.state_dict(), output_model_file)
                print("PASS")

        #loss_fout.close()

        # Save a trained model
        #logger.info("** ** * Saving fine-tuned model ** ** * ")
        #model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        #output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        #if args.do_train:
        #    torch.save(model_to_save.state_dict(), output_model_file)


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
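    # Editor note (hedged): despite the name and docstring, this variant only
    # measures and truncates tokens_a (the tokens_b term is commented out below),
    # so it effectively truncates a single sequence to max_length.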
    while True:
        #total_length = len(tokens_a) + len(tokens_b)
        total_length = len(tokens_a)
        if total_length <= max_length:
            break
        else:
            tokens_a.pop()

def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


if __name__ == "__main__":
    main()
