CSS-LM

pretrain_roberta_including_Preprocess_DomainTask_sentiment_noaspect_mean.py 
1791 lines · 77.6 KB
1
# coding=utf-8
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
"""BERT finetuning runner."""
17

18
from __future__ import absolute_import, division, print_function, unicode_literals
19

20
import argparse
21
import logging
22
import os
23
import random
24
from io import open
25
import json
26
import time
27

28
import numpy as np
29
import torch
30
from torch.utils.data import DataLoader, Dataset, RandomSampler
31
from torch.utils.data.distributed import DistributedSampler
32
from tqdm import tqdm, trange
33

34
from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
35
from transformers.modeling_roberta import RobertaForMaskedLMDomainTask
36
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
37

38
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
39
                    datefmt='%m/%d/%Y %H:%M:%S',
40
                    level=logging.INFO)
41
logger = logging.getLogger(__name__)
42

43

44
#def default_all_type_sentence(batch):
45

46

47

48
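# Wraps a pre-computed (weight, bias) pair in a frozen nn.Linear so it can be used as a fixed
# classification head on the CPU; gradients are disabled because the parameters are copied in,
# presumably from an already-trained head.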
def return_Classifier(weight, bias, dim_in, dim_out):
49
    #LeakyReLU = torch.nn.LeakyReLU
50
    classifier = torch.nn.Linear(dim_in, dim_out , bias=True)
51
    #print(classifier)
52
    #print(classifier.weight)
53
    #print(classifier.weight.shape)
54
    #print(classifier.weight.data)
55
    #print(classifier.weight.data.shape)
56
    #print("---")
57
    classifier.weight.data = weight.to("cpu")
58
    classifier.bias.data = bias.to("cpu")
59
    classifier.weight.requires_grad = False
    classifier.bias.requires_grad = False
60
    #print(classifier)
61
    #print(classifier.weight)
62
    #print(classifier.weight.shape)
63
    #print("---")
64
    #exit()
65
    #print(classifier)
66
    #exit()
67
    #logit = LeakyReLU(classifier)
68
    return classifier
69

70

71
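# Loads the pre-computed general-domain sentence vectors (a *_CLS.pt tensor) together with the raw
# sentences (a .json file).  Each supported corpus directory has its own hard-coded branch; any
# other path falls through and returns nothing.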
def load_GeneralDomain(dir_data_out):
72
    ###Test
73
    if dir_data_out=="data/open_domain_preprocessed_roberta/":
74
        docs = torch.load(dir_data_out+"opendomain_CLS.pt")
75
        with open(dir_data_out+"opendomain.json") as file:
76
            data = json.load(file)
77
        print("train.json Done")
78
        print("===========")
79
        docs = docs.unsqueeze(1)
80
        return docs, data
81
    ###
82

83

84
    ###
85
    elif dir_data_out=="data/yelp/":
86
        print("===========")
87
        print("Load CLS.pt and train.json")
88
        print("-----------")
89
        docs = torch.load(dir_data_out+"train_CLS.pt")
90
        print("CLS.pt Done")
91
        print(docs.shape)
92
        print("-----------")
93
        with open(dir_data_out+"train.json") as file:
94
            data = json.load(file)
95
        print("train.json Done")
96
        print("===========")
97
        return docs, data
98
    ###
99

100

101
    ###
102
    elif dir_data_out=="data/yelp_finetune_noword_10000/":
103
        print("===========")
104
        print("Load CLS.pt and train.json")
105
        print("-----------")
106
        docs = torch.load(dir_data_out+"train_CLS.pt")
107
        print("CLS.pt Done")
108
        print(docs.shape)
109
        print("-----------")
110
        with open(dir_data_out+"train.json") as file:
111
            data = json.load(file)
112
        print("train.json Done")
113
        print("===========")
114
        return docs, data
115
    ###
116

117

118
def load_GeneralDomain_docs(dir_data_out):
119
    ###Test
120
    if dir_data_out=="data/open_domain_preprocessed_roberta/":
121
        docs = torch.load(dir_data_out+"opendomain_CLS.pt")
122
        #with open(dir_data_out+"opendomain.json") as file:
123
        #    data = json.load(file)
124
        #print("train.json Done")
125
        print("===========")
126
        docs = docs.unsqueeze(1)
127
        return docs
128
        ###
129

130
    elif dir_data_out=="data/yelp/":
131
        ###
132
        print("===========")
133
        print("Load CLS.pt and train.json")
134
        print("-----------")
135
        docs = torch.load(dir_data_out+"train_CLS.pt")
136
        print("CLS.pt Done")
137
        print(docs.shape)
138
        #print("-----------")
139
        #with open(dir_data_out+"train.json") as file:
140
        #    data = json.load(file)
141
        #print("train.json Done")
142
        #print("===========")
143
        return docs
144
        ###
145

146

147
def load_GeneralDomain_data(dir_data_out):
148
    ###Test
149
    if dir_data_out=="data/open_domain_preprocessed_roberta/":
150
        #docs = torch.load(dir_data_out+"opendomain_CLS.pt")
151
        with open(dir_data_out+"opendomain.json") as file:
152
            data = json.load(file)
153
        #print("train.json Done")
154
        #print("===========")
155
        #docs = docs.unsqueeze(1)
156
        return data
157
        ###
158

159
    elif dir_data_out=="data/yelp/":
160
        ###
161
        #print("===========")
162
        #print("Load CLS.pt and train.json")
163
        #print("-----------")
164
        #docs = torch.load(dir_data_out+"train_CLS.pt")
165
        #print("CLS.pt Done")
166
        #print(docs.shape)
167
        print("-----------")
168
        with open(dir_data_out+"train.json") as file:
169
            data = json.load(file)
170
        print("train.json Done")
171
        print("===========")
172
        return data
173
        ###
174

175
#Load outDomainData
176
###Test
177
#docs, data = load_GeneralDomain("data/open_domain_preprocessed_roberta/")
178
#data = load_GeneralDomain_data("data/open_domain_preprocessed_roberta/")
179
#data = load_GeneralDomain_data("data/yelp/")
180
#docs = load_GeneralDomain_docs("data/yelp/")
181
######
182
###
183
#docs, data = load_GeneralDomain("data/open_domain_preprocessed_roberta/")
184
#docs, data = load_GeneralDomain("data/yelp")
185
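# Module-level load of the general-domain corpus: `docs` holds one vector per sentence (averaged
# over the stored hidden layers below when more than one layer was saved), and `data` maps a
# sentence index to its text for the AugmentationData_* helpers.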
docs, data = load_GeneralDomain("data/yelp_finetune_noword_10000/")
186
######
187
if docs.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
188
    #last <s>
189
    #docs = docs[:,0,:].unsqueeze(1)
190
    #mean 13 layers <s>
191
    docs = docs.mean(1).unsqueeze(1)
192
    print(docs.shape)
193
else:
194
    print(docs.shape)
195
######
196

197
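# Reads the in-domain train.json, builds the aspect and sentiment label maps, and converts every
# sentence into masked-LM feature tensors.  Returns (one example per sentiment class, all
# examples); the per-class examples are presumably used later as reference sentences for each label.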
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
198
    ###Open
199
    with open(data_dir_indomain+"train.json") as file:
200
        data = json.load(file)
201

202
    ###Preprocess
203
    num_label_list = list()
204
    label_sentence_dict = dict()
205
    num_sentiment_label_list = list()
206
    sentiment_label_dict = dict()
207
    for line in data:
208
        #line["sentence"]
209
        #line["aspect"]
210
        #line["sentiment"]
211
        num_sentiment_label_list.append(line["sentiment"])
212
        num_label_list.append(line["aspect"])
213

214
    num_label = sorted(list(set(num_label_list)))
215
    label_map = {label : i for i , label in enumerate(num_label)}
216
    num_sentiment_label = sorted(list(set(num_sentiment_label_list)))
217
    sentiment_label_map = {label : i for i , label in enumerate(num_sentiment_label)}
218
    print("=======")
219
    print("label_map:")
220
    print(label_map)
221
    print("=======")
222
    print("=======")
223
    print("sentiment_label_map:")
224
    print(sentiment_label_map)
225
    print("=======")
226

227
    ###Create data: 1 chosen example along with the rest of the 7 class data
228

229
    '''
230
    all_input_ids = list()
231
    all_input_mask = list()
232
    all_segment_ids = list()
233
    all_lm_labels_ids = list()
234
    all_is_next = list()
235
    all_tail_idxs = list()
236
    all_sentence_labels = list()
237
    '''
238
    cur_tensors_list = list()
239
    #print(list(label_map.values()))
240
    candidate_label_list = list(label_map.values())
241
    candidate_sentiment_label_list = list(sentiment_label_map.values())
242
    all_type_sentence = [0]*len(candidate_label_list)
243
    all_type_sentiment_sentence = [0]*len(candidate_sentiment_label_list)
244
    for line in data:
245
        #line["sentence"]
246
        #line["aspect"]
247
        sentiment = line["sentiment"]
248
        sentence = line["sentence"]
249
        label = line["aspect"]
250

251

252
        tokens_a = tokenizer.tokenize(sentence)
253
        #input_ids = tokenizer.encode(sentence, add_special_tokens=False)
254
        '''
255
        if "</s>" in tokens_a:
256
            print("Have more than 1 </s>")
257
            #tokens_a[tokens_a.index("<s>")] = "s"
258
            for i in range(len(tokens_a)):
259
                if tokens_a[i] == "</s>":
260
                    tokens_a[i] == "s"
261
        '''
262

263

264
        # tokenize
265
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
266
        # transform sample to features
267
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
268

269
        cur_tensors = (torch.tensor(cur_features.input_ids),
270
                       torch.tensor(cur_features.input_ids_org),
271
                       torch.tensor(cur_features.input_mask),
272
                       torch.tensor(cur_features.segment_ids),
273
                       torch.tensor(cur_features.lm_label_ids),
274
                       torch.tensor(0),
275
                       torch.tensor(cur_features.tail_idxs),
276
                       torch.tensor(label_map[label]),
277
                       torch.tensor(sentiment_label_map[sentiment]))
278

279
        cur_tensors_list.append(cur_tensors)
280

281
        ###
282
        if label_map[label] in candidate_label_list:
283
            all_type_sentence[label_map[label]]=cur_tensors
284
            candidate_label_list.remove(label_map[label])
285

286
        if sentiment_label_map[sentiment] in candidate_sentiment_label_list:
287
            #print("----")
288
            #print(sentiment_label_map[sentiment])
289
            #print("----")
290
            all_type_sentiment_sentence[sentiment_label_map[sentiment]]=cur_tensors
291
            candidate_sentiment_label_list.remove(sentiment_label_map[sentiment])
292
        ###
293

294

295

296

297
    '''
298
        all_input_ids.append(torch.tensor(cur_features.input_ids))
299
        all_input_mask.append(torch.tensor(cur_features.input_mask))
300
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
301
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
302
        all_is_next.append(torch.tensor(0))
303
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
304
        all_sentence_labels.append(torch.tensor(label_map[label]))
305

306
    cur_tensors = (torch.stack(all_input_ids),
307
                   torch.stack(all_input_mask),
308
                   torch.stack(all_segment_ids),
309
                   torch.stack(all_lm_labels_ids),
310
                   torch.stack(all_is_next),
311
                   torch.stack(all_tail_idxs),
312
                   torch.stack(all_sentence_labels))
313
    '''
314

315

316
    '''
317
    print("=====")
318
    print(candidate_label_list)
319
    print("---")
320
    print(all_type_sentence)
321
    print("---")
322
    print(len(cur_tensors_list))
323
    exit()
324
    '''
325

326
    #return cur_tensors
327
    #for line in all_type_sentiment_sentence:
328
    #    print(line[-1])
329
    #exit()
330
    return all_type_sentiment_sentence, cur_tensors_list
331

332

333
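# For each in-domain example, pairs the chosen sentence with one randomly sampled sentence from
# every other aspect class and stacks the group into a single tensor tuple.  This helper does not
# appear to be called from main() in this variant of the script.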
def in_Domain_Task_Data_binary(data_dir_indomain, tokenizer, max_seq_length):
334
    ###Open
335
    with open(data_dir_indomain+"train.json") as file:
336
        data = json.load(file)
337

338
    ###Preprocess
339
    num_label_list = list()
340
    label_sentence_dict = dict()
341
    for line in data:
342
        #line["sentence"]
343
        #line["aspect"]
344
        #line["sentiment"]
345
        num_label_list.append(line["aspect"])
346
        try:
347
            label_sentence_dict[line["aspect"]].append([line["sentence"]])
348
        except KeyError:
349
            label_sentence_dict[line["aspect"]] = [[line["sentence"]]]  # keep entries list-wrapped, matching the append above
350

351
    num_label = sorted(list(set(num_label_list)))
352
    label_map = {label : i for i , label in enumerate(num_label)}
353

354
    ###Create data: 1 chosen example along with the rest of the 7 class data
355
    all_cur_tensors = list()
356
    for line in data:
357
        #line["sentence"]
358
        #line["aspect"]
359
        #line["sentiment"]
360
        sentence = line["sentence"]
361
        label = line["aspect"]
362
        sentence_out = [(random.choice(label_sentence_dict[label_out])[0], label_out) for label_out in num_label if label_out!=label]
363
        all_sentence = [(sentence, label)] + sentence_out #1st sentence is the chosen one
364

365
        all_input_ids = list()
        all_input_ids_org = list()
366
        all_input_mask = list()
367
        all_segment_ids = list()
368
        all_lm_labels_ids = list()
369
        all_is_next = list()
370
        all_tail_idxs = list()
371
        all_sentence_labels = list()
372
        for id, sentence_label in enumerate(all_sentence):
373
            #tokens_a = tokenizer.tokenize(sentence_label[0])
374
            tokens_a = tokenizer.tokenize(sentence_label[0])
375
            '''
376
            if "</s>" in tokens_a:
377
                print("Have more than 1 </s>")
378
                for i in range(len(tokens_a)):
379
                    if tokens_a[i] == "</s>":
380
                        tokens_a[i] = "s"
381
            '''
382

383
            # tokenize
384
            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
385
            # transform sample to features
386
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
387

388
            all_input_ids.append(torch.tensor(cur_features.input_ids))
389
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
390
            all_input_mask.append(torch.tensor(cur_features.input_mask))
391
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
392
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
393
            all_is_next.append(torch.tensor(0))
394
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
395
            all_sentence_labels.append(torch.tensor(label_map[sentence_label[1]]))
396

397
        cur_tensors = (torch.stack(all_input_ids),
398
                       torch.stack(all_input_ids_org),
399
                       torch.stack(all_input_mask),
400
                       torch.stack(all_segment_ids),
401
                       torch.stack(all_lm_labels_ids),
402
                       torch.stack(all_is_next),
403
                       torch.stack(all_tail_idxs),
404
                       torch.stack(all_sentence_labels))
405

406
        all_cur_tensors.append(cur_tensors)
407

408
    return all_cur_tensors
409

410

411

412
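# Maps the retrieved top-k open-domain sentence indices (top_k["indices"]) back to their text via
# the module-level `data` dict and converts them into a batch of masked-LM feature tensors.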
def AugmentationData_Domain(top_k, tokenizer, max_seq_length):
413
    #top_k_shape = top_k.indices.shape
414
    #ids = top_k.indices.reshape(top_k_shape[0]*top_k_shape[1]).tolist()
415
    top_k_shape = top_k["indices"].shape
416
    ids = top_k["indices"].reshape(top_k_shape[0]*top_k_shape[1]).tolist()
417

418
    all_input_ids = list()
419
    all_input_ids_org = list()
420
    all_input_mask = list()
421
    all_segment_ids = list()
422
    all_lm_labels_ids = list()
423
    all_is_next = list()
424
    all_tail_idxs = list()
425

426
    for id, i in enumerate(ids):
427
        t1 = data[str(i)]['sentence']
428

429
        #tokens_a = tokenizer.tokenize(t1)
430
        tokens_a = tokenizer.tokenize(t1)
431
        '''
432
        if "</s>" in tokens_a:
433
            print("Have more than 1 </s>")
434
            #tokens_a[tokens_a.index("<s>")] = "s"
435
            for i in range(len(tokens_a)):
436
                if tokens_a[i] == "</s>":
437
                    tokens_a[i] = "s"
438
        '''
439

440
        # tokenize
441
        cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
442

443
        # transform sample to features
444
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
445

446
        all_input_ids.append(torch.tensor(cur_features.input_ids))
447
        all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
448
        all_input_mask.append(torch.tensor(cur_features.input_mask))
449
        all_segment_ids.append(torch.tensor(cur_features.segment_ids))
450
        all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
451
        all_is_next.append(torch.tensor(0))
452
        all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
453

454

455
    cur_tensors = (torch.stack(all_input_ids),
456
                   torch.stack(all_input_ids_org),
457
                   torch.stack(all_input_mask),
458
                   torch.stack(all_segment_ids),
459
                   torch.stack(all_lm_labels_ids),
460
                   torch.stack(all_is_next),
461
                   torch.stack(all_tail_idxs))
462

463
    return cur_tensors
464

465

466
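# Like AugmentationData_Domain, but keeps the task labels: every retrieved sentence inherits the
# aspect/sentiment label of the in-domain example it was retrieved for, and the original examples
# from `add_org` are appended to the end of the batch.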
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
467
    top_k_shape = top_k["indices"].shape
468
    sentence_ids = top_k["indices"]
469

470
    all_input_ids = list()
471
    all_input_ids_org = list()
472
    all_input_mask = list()
473
    all_segment_ids = list()
474
    all_lm_labels_ids = list()
475
    all_is_next = list()
476
    all_tail_idxs = list()
477
    all_sentence_labels = list()
478
    all_sentiment_labels = list()
479

480
    add_org = tuple(t.to('cpu') for t in add_org)
481
    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
482
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
483

484
    ###
485
    #print("input_ids_",input_ids_.shape)
486
    #print("---")
487
    #print("sentence_ids",sentence_ids.shape)
488
    #print("---")
489
    #print("sentence_label_",sentence_label_.shape)
490
    #exit()
491

492

493
    for id_1, sent in enumerate(sentence_ids):
494
        for id_2, sent_id in enumerate(sent):
495

496
            t1 = data[str(int(sent_id))]['sentence']
497

498
            tokens_a = tokenizer.tokenize(t1)
499

500
            # tokenize
501
            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
502

503
            # transform sample to features
504
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
505

506
            all_input_ids.append(torch.tensor(cur_features.input_ids))
507
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
508
            all_input_mask.append(torch.tensor(cur_features.input_mask))
509
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
510
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
511
            all_is_next.append(torch.tensor(0))
512
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
513
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
514
            all_sentiment_labels.append(torch.tensor(sentiment_label_[id_1]))
515
            '''
516
            #if len(sentence_label_) != len(sentence_label_):
517
            #    print(len(sentence_label_) != len(sentence_label_))
518
            try:
519
                all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
520
            except:
521
                #all_sentence_labels.append(torch.tensor([0]))
522
                print(sentence_ids)
523
                print(sentence_label_)
524
                print("==========================")
525
                print("input_ids_",input_ids_.shape)
526
                print("---")
527
                print("sentence_ids",sentence_ids.shape)
528
                print("---")
529
                print("sentence_label_",sentence_label_.shape)
530
                exit()
531
            '''
532

533
        all_input_ids.append(input_ids_[id_1])
534
        all_input_ids_org.append(input_ids_org_[id_1])
535
        all_input_mask.append(input_mask_[id_1])
536
        all_segment_ids.append(segment_ids_[id_1])
537
        all_lm_labels_ids.append(lm_label_ids_[id_1])
538
        all_is_next.append(is_next_[id_1])
539
        all_tail_idxs.append(tail_idxs_[id_1])
540
        all_sentence_labels.append(sentence_label_[id_1])
541
        all_sentiment_labels.append(sentiment_label_[id_1])
542

543

544
    cur_tensors = (torch.stack(all_input_ids),
545
                   torch.stack(all_input_ids_org),
546
                   torch.stack(all_input_mask),
547
                   torch.stack(all_segment_ids),
548
                   torch.stack(all_lm_labels_ids),
549
                   torch.stack(all_is_next),
550
                   torch.stack(all_tail_idxs),
551
                   torch.stack(all_sentence_labels),
552
                   torch.stack(all_sentiment_labels)
553
                   )
554

555

556
    return cur_tensors
557

558

559
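# Builds pairwise data for a binary "same label or not" objective: each example's task
# representation is concatenated with every other example's representation, and the target is 1
# when the two examples share the same sentence label, 0 otherwise.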
def AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=None, max_seq_length=None, add_org=None, in_task_rep=None):
560
    '''
561
    top_k_shape = top_k.indices.shape
562
    sentence_ids = top_k.indices
563
    '''
564
    #top_k_shape = top_k["indices"].shape
565
    #sentence_ids = top_k["indices"]
566

567

568
    #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = add_org
569
    input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = add_org
570

571
    #print(sentence_label_.shape)
572

573
    all_sentence_binary_label = list()
574
    all_in_task_rep_comb = list()
575
    for id_1, num in enumerate(sentence_label_):
576
        #print([sentence_label_==num])
577
        #print(type([sentence_label_==num]))
578
        sentence_label_int = (sentence_label_==num).to(torch.long)
579
        #print(in_task_rep[id_1].shape)
580
        in_task_rep_append = in_task_rep[id_1].unsqueeze(0).expand(in_task_rep.shape[0],-1)
581
        in_task_rep_comb = torch.cat((in_task_rep_append,in_task_rep),-1)
582
        #print(in_task_rep_comb.shape)
583
        #exit()
584
        #sentence_label_int = sentence_label_int.to(torch.float32)
585
        #print(sentence_label_int)
586
        #exit()
587
        #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
588
        #all_sentence_binary_label.append(torch.tensor([1 if num==iid else 0 for iid in sentence_label_]))
589
        all_sentence_binary_label.append(sentence_label_int)
590
        all_in_task_rep_comb.append(in_task_rep_comb)
591
    all_sentence_binary_label = torch.stack(all_sentence_binary_label)
592
    all_in_task_rep_comb = torch.stack(all_in_task_rep_comb)
593

594
    cur_tensors = (all_in_task_rep_comb, all_sentence_binary_label)
595

596
    return cur_tensors
597

598

599
    '''
600
    all_input_ids_batch = list()
601
    all_input_ids_org_batch = list()
602
    all_input_mask_batch = list()
603
    all_segment_ids_batch = list()
604
    all_lm_labels_ids_batch = list()
605
    all_is_next_batch = list()
606
    all_tail_idxs_batch = list()
607
    all_sentence_labels_batch = list()
608
    all_sentence_binary_label_batch = list()
609

610
    for id_1, sent in enumerate(sentence_ids):
611
        all_input_ids = list()
612
        all_input_ids_org = list()
613
        all_input_mask = list()
614
        all_segment_ids = list()
615
        all_lm_labels_ids = list()
616
        all_is_next = list()
617
        all_tail_idxs = list()
618
        all_sentence_labels = list()
619
        all_sentence_binary_label = list()
620

621
        #print(sent)
622
        #print(sent.shape)
623
        #exit()
624
        for id_2, sent_id in enumerate(sent):
625

626
            t1 = data[str(int(sent_id))]['sentence']
627

628
            #tokens_a = tokenizer.tokenize(t1)
629
            tokens_a = tokenizer.tokenize(t1)
630

631
            # tokenize
632
            cur_example = InputExample(guid=id, tokens_a=tokens_a, tokens_b=None, is_next=0)
633

634
            # transform sample to features
635
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)
636

637
            all_input_ids.append(torch.tensor(cur_features.input_ids))
638
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
639
            all_input_mask.append(torch.tensor(cur_features.input_mask))
640
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
641
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
642
            all_is_next.append(torch.tensor(0))
643
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
644
            all_sentence_labels.append(torch.tensor(sentence_label_[id_1]))
645
            #all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
646
        all_sentence_binary_label.append(torch.tensor([1 if sentence_label_[id_1]==iid else 0 for iid in sentence_label_]))
647
        #all_sentence_binary_label = torch.tensor(all_sentence_binary_label)
648
        #print(all_sentence_binary_label)
649
        #print(all_sentence_binary_label[0].shape)
650
        #exit()
651

652
        all_input_ids_batch.append(torch.stack(all_input_ids))
653
        all_input_ids_org_batch.append(torch.stack(all_input_ids_org))
654
        all_input_mask_batch.append(torch.stack(all_input_mask))
655
        all_segment_ids_batch.append(torch.stack(all_segment_ids))
656
        all_lm_labels_ids_batch.append(torch.stack(all_lm_labels_ids))
657
        all_is_next_batch.append(torch.stack(all_is_next))
658
        all_tail_idxs_batch.append(torch.stack(all_tail_idxs))
659
        all_sentence_labels_batch.append(torch.stack(all_sentence_labels))
660
        all_sentence_binary_label_batch.append(torch.stack(all_sentence_binary_label))
661
        #print(all_sentence_binary_label_batch)
662
        #print(all_sentence_binary_label_batch[0].shape)
663
        #exit()
664
    #print("===")
665
    #print(all_sentence_binary_label_batch)
666
    #print(len(all_sentence_binary_label_batch), len(all_sentence_binary_label_batch[0]))
667
    #exit()
668

669

670

671
    cur_tensors = (torch.stack(all_input_ids_batch),
672
                   torch.stack(all_input_ids_org_batch),
673
                   torch.stack(all_input_mask_batch),
674
                   torch.stack(all_segment_ids_batch),
675
                   torch.stack(all_lm_labels_ids_batch),
676
                   torch.stack(all_is_next_batch),
677
                   torch.stack(all_tail_idxs_batch),
678
                   torch.stack(all_sentence_labels_batch),
679
                   torch.stack(all_sentence_binary_label_batch)
680
                   )
681

682

683
    return cur_tensors
684
    '''
685

686

687

688
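# Line-per-sentence corpus reader kept from the original BERT pre-training script, with the
# next-sentence-prediction part stripped out (random_sent always returns a single sentence).
# It does not appear to be used by the task-data path above, which reads train.json instead.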
class Dataset_noNext(Dataset):
689
    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
690

691
        self.vocab_size = tokenizer.vocab_size
692
        self.tokenizer = tokenizer
693
        self.seq_len = seq_len
694
        self.on_memory = on_memory
695
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
696
        self.corpus_path = corpus_path
697
        self.encoding = encoding
698
        self.current_doc = 0  # to avoid random sentence from same doc
699

700
        # for loading samples directly from file
701
        self.sample_counter = 0  # used to keep track of full epochs on file
702
        self.line_buffer = None  # keep second sentence of a pair in memory and use as first sentence in next pair
703

704
        # for loading samples in memory
705
        self.current_random_doc = 0
706
        self.num_docs = 0
707
        self.sample_to_doc = [] # map sample index to doc and line
708

709
        # load samples into memory
710
        if on_memory:
711
            self.all_docs = []
712
            doc = []
713
            self.corpus_lines = 0
714
            with open(corpus_path, "r", encoding=encoding) as f:
715
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
716
                    line = line.strip()
717
                    if line == "":
718
                        self.all_docs.append(doc)
719
                        doc = []
720
                        #remove last added sample because there won't be a subsequent line anymore in the doc
721
                        self.sample_to_doc.pop()
722
                    else:
723
                        #store as one sample
724
                        sample = {"doc_id": len(self.all_docs),
725
                                  "line": len(doc)}
726
                        self.sample_to_doc.append(sample)
727
                        doc.append(line)
728
                        self.corpus_lines = self.corpus_lines + 1
729

730
            # if last row in file is not empty
731
            if self.all_docs[-1] != doc:
732
                self.all_docs.append(doc)
733
                self.sample_to_doc.pop()
734

735
            self.num_docs = len(self.all_docs)
736

737
        # load samples later lazily from disk
738
        else:
739
            if self.corpus_lines is None:
740
                with open(corpus_path, "r", encoding=encoding) as f:
741
                    self.corpus_lines = 0
742
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
743
                        if line.strip() == "":
744
                            self.num_docs += 1
745
                        else:
746
                            self.corpus_lines += 1
747

748
                    # if doc does not end with empty line
749
                    if line.strip() != "":
750
                        self.num_docs += 1
751

752
            self.file = open(corpus_path, "r", encoding=encoding)
753
            self.random_file = open(corpus_path, "r", encoding=encoding)
754

755
    def __len__(self):
756
        # last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
757
        return self.corpus_lines - self.num_docs - 1
758

759
    def __getitem__(self, item):
760
        cur_id = self.sample_counter
761
        self.sample_counter += 1
762
        if not self.on_memory:
763
            # after one epoch we start again from beginning of file
764
            if cur_id != 0 and (cur_id % len(self) == 0):
765
                self.file.close()
766
                self.file = open(self.corpus_path, "r", encoding=self.encoding)
767

768
        #t1, t2, is_next_label = self.random_sent(item)
769
        t1, is_next_label = self.random_sent(item)
770
        if is_next_label is None:
771
            is_next_label = 0
772

773

774
        #tokens_a = self.tokenizer.tokenize(t1)
775
        tokens_a = self.tokenizer.tokenize(t1)
776
        '''
777
        if "</s>" in tokens_a:
778
            print("Have more than 1 </s>")
779
            #tokens_a[tokens_a.index("<s>")] = "s"
780
            for i in range(len(tokens_a)):
781
                if tokens_a[i] == "</s>":
782
                    tokens_a[i] = "s"
783
        '''
784
        #tokens_b = self.tokenizer.tokenize(t2)
785

786
        # tokenize
787
        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)
788

789
        # transform sample to features
790
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
791

792
        cur_tensors = (torch.tensor(cur_features.input_ids),
793
                       torch.tensor(cur_features.input_ids_org),
794
                       torch.tensor(cur_features.input_mask),
795
                       torch.tensor(cur_features.segment_ids),
796
                       torch.tensor(cur_features.lm_label_ids),
797
                       torch.tensor(cur_features.is_next),
798
                       torch.tensor(cur_features.tail_idxs))
799

800
        return cur_tensors
801

802
    def random_sent(self, index):
803
        """
804
        Get one sentence from the corpus.  The original 50% next-sentence sampling has been
        removed, so only the first sentence is returned.
        :param index: int, index of sample.
        :return: (str, None), the sentence and a placeholder for the unused second sentence.
808
        """
809
        t1, t2 = self.get_corpus_line(index)
810
        return t1, None
811

812
    def get_corpus_line(self, item):
813
        """
814
        Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
815
        :param item: int, index of sample.
816
        :return: (str, str), two subsequent sentences from corpus
817
        """
818
        t1 = ""
819
        t2 = ""
820
        assert item < self.corpus_lines
821
        if self.on_memory:
822
            sample = self.sample_to_doc[item]
823
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
824
            # used later to avoid random nextSentence from same doc
825
            self.current_doc = sample["doc_id"]
826
            return t1, t2
827
            #return t1
828
        else:
829
            if self.line_buffer is None:
830
                # read first non-empty line of file
831
                while t1 == "" :
832
                    t1 = next(self.file).strip()
833
            else:
834
                # use t2 from previous iteration as new t1
835
                t1 = self.line_buffer
836
                # skip empty rows that are used for separating documents and keep track of current doc id
837
                while t1 == "":
838
                    t1 = next(self.file).strip()
839
                    self.current_doc = self.current_doc+1
840
            self.line_buffer = next(self.file).strip()
841

842
        assert t1 != ""
843
        return t1, t2
844

845

846
    def get_random_line(self):
847
        """
848
        Get random line from another document for nextSentence task.
849
        :return: str, content of one line
850
        """
851
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
852
        # corpora. However, just to be careful, we try to make sure that
853
        # the random document is not the same as the document we're processing.
854
        for _ in range(10):
855
            if self.on_memory:
856
                rand_doc_idx = random.randint(0, len(self.all_docs)-1)
857
                rand_doc = self.all_docs[rand_doc_idx]
858
                line = rand_doc[random.randrange(len(rand_doc))]
859
            else:
860
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
861
                #pick random line
862
                for _ in range(rand_index):
863
                    line = self.get_next_line()
864
            #check if our picked random line is really from another doc like we want it to be
865
            if self.current_random_doc != self.current_doc:
866
                break
867
        return line
868

869
    def get_next_line(self):
870
        """ Gets next line of random_file and starts over when reaching end of file"""
871
        try:
872
            line = next(self.random_file).strip()
873
            #keep track of which document we are currently looking at to later avoid having the same doc as t1
874
            if line == "":
875
                self.current_random_doc = self.current_random_doc + 1
876
                line = next(self.random_file).strip()
877
        except StopIteration:
878
            self.random_file.close()
879
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
880
            line = next(self.random_file).strip()
881
        return line
882

883

884
class InputExample(object):
885
    """A single training/test example for the language model."""
886

887
    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
888
        """Constructs a InputExample.
889
        Args:
890
            guid: Unique id for the example.
891
            tokens_a: string. The untokenized text of the first sequence. For single
892
            sequence tasks, only this sequence must be specified.
893
            tokens_b: (Optional) string. The untokenized text of the second sequence.
894
            Only must be specified for sequence pair tasks.
895
            label: (Optional) string. The label of the example. This should be
896
            specified for train and dev examples, but not for test examples.
897
        """
898
        self.guid = guid
899
        self.tokens_a = tokens_a
900
        self.tokens_b = tokens_b
901
        self.is_next = is_next  # nextSentence
902
        self.lm_labels = lm_labels  # masked words for language model
903

904

905
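# Container for one example: input_ids is the masked sequence, input_ids_org the unmasked
# original, and tail_idxs is len(input_ids) + 1 measured before padding (see
# convert_example_to_features below).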
class InputFeatures(object):
906
    """A single set of features of data."""
907

908
    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
909
        self.input_ids = input_ids
910
        self.input_ids_org = input_ids_org
911
        self.input_mask = input_mask
912
        self.segment_ids = segment_ids
913
        self.is_next = is_next
914
        self.lm_label_ids = lm_label_ids
915
        self.tail_idxs = tail_idxs
916

917

918
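# RoBERTa-style MLM masking: "<mask>" is used instead of "[MASK]", and unmasked positions get
# label -1 so the LM loss ignores them.  Illustrative use (the tokenizer name is just an example):
#   tokens, labels = random_word(tokenizer.tokenize("the movie was great"), tokenizer)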
def random_word(tokens, tokenizer):
919
    """
920
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
921
    :param tokens: list of str, tokenized sentence.
922
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
923
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
924
    """
925
    output_label = []
926

927
    for i, token in enumerate(tokens):
928

929
        prob = random.random()
930
        # mask token with 15% probability
931
        if prob < 0.15:
932
            prob /= 0.15
933
            #candidate_id = random.randint(0,tokenizer.vocab_size)
934
            #print(tokenizer.convert_ids_to_tokens(candidate_id))
935

936

937
            # 80% randomly change token to mask token
938
            if prob < 0.8:
939
                #tokens[i] = "[MASK]"
940
                tokens[i] = "<mask>"
941

942
            # 10% randomly change token to random token
943
            elif prob < 0.9:
944
                #tokens[i] = random.choice(list(tokenizer.vocab.items()))[0]
945
                #tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
946
                candidate_id = random.randint(0, tokenizer.vocab_size - 1)
947
                w = tokenizer.convert_ids_to_tokens(candidate_id)
948
                '''
949
                if tokens[i] == None:
950
                    candidate_id = 100
951
                    w = tokenizer.convert_ids_to_tokens(candidate_id)
952
                '''
953
                tokens[i] = w
954

955

956
            # -> rest 10% randomly keep current token
957

958
            # append current token to output (we will predict these later)
959
            try:
960
                #output_label.append(tokenizer.vocab[token])
961
                w = tokenizer.convert_tokens_to_ids(token)
962
                if w is not None:
963
                    output_label.append(w)
964
                else:
965
                    print("Have no this tokens in ids")
966
                    exit()
967
            except KeyError:
968
                # For unknown words (should not occur with BPE vocab)
969
                #output_label.append(tokenizer.vocab["<unk>"])
970
                w = tokenizer.convert_tokens_to_ids("<unk>")
971
                output_label.append(w)
972
                logger.warning("Cannot find token '{}' in vocab. Using <unk> insetad".format(token))
973
        else:
974
            # no masking token (will be ignored by loss function later)
975
            output_label.append(-1)
976

977
    return tokens, output_label
978

979

980
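# RoBERTa variant of the original BERT feature builder: sequences are wrapped in <s> ... </s>
# instead of [CLS]/[SEP], both the masked ids (input_ids) and the unmasked ids (input_ids_org) are
# kept, and any stray "</s>" inside the sentence is replaced by a plain "s".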
def convert_example_to_features(example, max_seq_length, tokenizer):
981
    """
982
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
983
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
984
    :param example: InputExample, containing sentence input as strings and is_next label
985
    :param max_seq_length: int, maximum length of sequence.
986
    :param tokenizer: Tokenizer
987
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
988
    """
989
    #now tokens_a is input_ids
990
    tokens_a = example.tokens_a
991
    tokens_b = example.tokens_b
992
    # Modifies `tokens_a` and `tokens_b` in place so that the total
993
    # length is less than the specified length.
994
    # Account for [CLS], [SEP], [SEP] with "- 3"
995
    #_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
996
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)
997

998
    #print(tokens_a)
999
    tokens_a_org = tokens_a.copy()
1000
    tokens_a, t1_label = random_word(tokens_a, tokenizer)
1001
    #print("----")
1002
    #print(tokens_a)
1003
    #print(tokens_a_org)
1004
    #exit()
1005
    #print(t1_label)
1006
    #exit()
1007
    #tokens_b, t2_label = random_word(tokens_b, tokenizer)
1008
    # concatenate lm labels and account for CLS, SEP, SEP
1009
    #lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])
1010
    lm_label_ids = ([-1] + t1_label + [-1])
1011

1012
    # The convention in BERT is:
1013
    # (a) For sequence pairs:
1014
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
1015
    #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
1016
    # (b) For single sequences:
1017
    #  tokens:   [CLS] the dog is hairy . [SEP]
1018
    #  type_ids: 0   0   0   0  0     0 0
1019
    #
1020
    # Where "type_ids" are used to indicate whether this is the first
1021
    # sequence or the second sequence. The embedding vectors for `type=0` and
1022
    # `type=1` were learned during pre-training and are added to the wordpiece
1023
    # embedding vector (and position vector). This is not *strictly* necessary
1024
    # since the [SEP] token unambiguously separates the sequences, but it makes
1025
    # it easier for the model to learn the concept of sequences.
1026
    #
1027
    # For classification tasks, the first vector (corresponding to [CLS]) is
1028
    # used as the "sentence vector". Note that this only makes sense because
1029
    # the entire model is fine-tuned.
1030
    tokens = []
1031
    tokens_org = []
1032
    segment_ids = []
1033
    tokens.append("<s>")
1034
    tokens_org.append("<s>")
1035
    segment_ids.append(0)
1036
    for i, token in enumerate(tokens_a):
1037
        if token!="</s>":
1038
            tokens.append(tokens_a[i])
1039
            tokens_org.append(tokens_a_org[i])
1040
            segment_ids.append(0)
1041
        else:
1042
            tokens.append("s")
1043
            tokens_org.append("s")
1044
            segment_ids.append(0)
1045
    tokens.append("</s>")
1046
    tokens_org.append("</s>")
1047
    segment_ids.append(0)
1048

1049
    #tokens.append("[SEP]")
1050
    #segment_ids.append(1)
1051

1052
    #input_ids = tokenizer.convert_tokens_to_ids(tokens)
1053
    input_ids = tokenizer.encode(tokens, add_special_tokens=False)
1054
    input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
1055
    tail_idxs = len(input_ids)+1
1056

1057
    #print(input_ids)
1058
    input_ids = [w if w!=None else 0 for w in input_ids]
1059
    input_ids_org = [w if w!=None else 0 for w in input_ids_org]
1060
    #print(input_ids)
1061
    #exit()
1062

1063
    # The mask has 1 for real tokens and 0 for padding tokens. Only real
1064
    # tokens are attended to.
1065
    input_mask = [1] * len(input_ids)
1066

1067
    # Zero-pad up to the sequence length.
1068
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
1069
    while len(input_ids) < max_seq_length:
1070
        input_ids.append(pad_id)
1071
        input_ids_org.append(pad_id)
1072
        input_mask.append(0)
1073
        segment_ids.append(0)
1074
        lm_label_ids.append(-1)
1075

1076

1077
    assert len(input_ids) == max_seq_length
1078
    assert len(input_ids_org) == max_seq_length
1079
    assert len(input_mask) == max_seq_length
1080
    assert len(segment_ids) == max_seq_length
1081
    assert len(lm_label_ids) == max_seq_length
1082

1083
    '''
1084
    if example.guid < 5:
1085
        logger.info("*** Example ***")
1086
        logger.info("guid: %s" % (example.guid))
1087
        logger.info("tokens: %s" % " ".join(
1088
                [str(x) for x in tokens]))
1089
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
1090
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
1091
        logger.info(
1092
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
1093
        logger.info("LM label: %s " % (lm_label_ids))
1094
        logger.info("Is next sentence label: %s " % (example.is_next))
1095
    '''
1096

1097
    features = InputFeatures(input_ids=input_ids,
1098
                             input_ids_org = input_ids_org,
1099
                             input_mask=input_mask,
1100
                             segment_ids=segment_ids,
1101
                             lm_label_ids=lm_label_ids,
1102
                             is_next=example.is_next,
1103
                             tail_idxs=tail_idxs)
1104
    return features
1105

1106

1107
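# Example invocation (illustrative only; the paths and values below are placeholders):
#   python pretrain_roberta_including_Preprocess_DomainTask_sentiment_noaspect_mean.py \
#       --data_dir_indomain data/restaurant/ \
#       --data_dir_outdomain data/yelp_finetune_noword_10000/ \
#       --pretrain_model roberta-base \
#       --output_dir output_dir/ \
#       --augment_times 1 --num_labels_task 2 --task 1 \
#       --max_seq_length 128 --train_batch_size 8 --do_train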
def main():
1108
    parser = argparse.ArgumentParser()
1109

1110
    ## Required parameters
1111
    parser.add_argument("--data_dir_indomain",
1112
                        default=None,
1113
                        type=str,
1114
                        required=True,
1115
                        help="The input train corpus.(In Domain)")
1116
    parser.add_argument("--data_dir_outdomain",
1117
                        default=None,
1118
                        type=str,
1119
                        required=True,
1120
                        help="The input train corpus.(Out Domain)")
1121
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
1122
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
1123
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
1124
    parser.add_argument("--output_dir",
1125
                        default=None,
1126
                        type=str,
1127
                        required=True,
1128
                        help="The output directory where the model checkpoints will be written.")
1129
    parser.add_argument("--augment_times",
1130
                        default=None,
1131
                        type=int,
1132
                        required=True,
1133
                        help="Default batch_size/augment_times to save model")
1134
    ## Other parameters
1135
    parser.add_argument("--max_seq_length",
1136
                        default=128,
1137
                        type=int,
1138
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
1139
                             "Sequences longer than this will be truncated, and sequences shorter \n"
1140
                             "than this will be padded.")
1141
    parser.add_argument("--do_train",
1142
                        action='store_true',
1143
                        help="Whether to run training.")
1144
    parser.add_argument("--train_batch_size",
1145
                        default=32,
1146
                        type=int,
1147
                        help="Total batch size for training.")
1148
    parser.add_argument("--learning_rate",
1149
                        default=3e-5,
1150
                        type=float,
1151
                        help="The initial learning rate for Adam.")
1152
    parser.add_argument("--num_train_epochs",
1153
                        default=3.0,
1154
                        type=float,
1155
                        help="Total number of training epochs to perform.")
1156
    parser.add_argument("--warmup_proportion",
1157
                        default=0.1,
1158
                        type=float,
1159
                        help="Proportion of training to perform linear learning rate warmup for. "
1160
                             "E.g., 0.1 = 10%% of training.")
1161
    parser.add_argument("--no_cuda",
1162
                        action='store_true',
1163
                        help="Whether not to use CUDA when available")
1164
    parser.add_argument("--on_memory",
1165
                        action='store_true',
1166
                        help="Whether to load train samples into memory or use disk")
1167
    parser.add_argument("--do_lower_case",
1168
                        action='store_true',
1169
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
1170
    parser.add_argument("--local_rank",
1171
                        type=int,
1172
                        default=-1,
1173
                        help="local_rank for distributed training on gpus")
1174
    parser.add_argument('--seed',
1175
                        type=int,
1176
                        default=42,
1177
                        help="random seed for initialization")
1178
    parser.add_argument('--gradient_accumulation_steps',
1179
                        type=int,
1180
                        default=1,
1181
                        help="Number of updates steps to accumualte before performing a backward/update pass.")
1182
    parser.add_argument('--fp16',
1183
                        action='store_true',
1184
                        help="Whether to use 16-bit float precision instead of 32-bit")
1185
    parser.add_argument('--loss_scale',
1186
                        type = float, default = 0,
1187
                        help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
1188
                        "0 (default value): dynamic loss scaling.\n"
1189
                        "Positive power of 2: static loss scaling value.\n")
1190
    ####
1191
    parser.add_argument("--num_labels_task",
1192
                        default=None, type=int,
1193
                        required=True,
1194
                        help="num_labels_task")
1195
    parser.add_argument("--weight_decay",
1196
                        default=0.0,
1197
                        type=float,
1198
                        help="Weight decay if we apply some.")
1199
    parser.add_argument("--adam_epsilon",
1200
                        default=1e-8,
1201
                        type=float,
1202
                        help="Epsilon for Adam optimizer.")
1203
    parser.add_argument("--max_grad_norm",
1204
                        default=1.0,
1205
                        type=float,
1206
                        help="Max gradient norm.")
1207
    parser.add_argument('--fp16_opt_level',
1208
                        type=str,
1209
                        default='O1',
1210
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
1211
                             "See details at https://nvidia.github.io/apex/amp.html")
1212
    parser.add_argument("--task",
1213
                        default=None,
1214
                        type=int,
1215
                        required=True,
1216
                        help="Choose Task")
1217
    ####
1218

1219
    args = parser.parse_args()
1220

1221
    if args.local_rank == -1 or args.no_cuda:
1222
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
1223
        n_gpu = torch.cuda.device_count()
1224
    else:
1225
        torch.cuda.set_device(args.local_rank)
1226
        device = torch.device("cuda", args.local_rank)
1227
        n_gpu = 1
1228
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
1229
        torch.distributed.init_process_group(backend='nccl')
1230
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
1231
        device, n_gpu, bool(args.local_rank != -1), args.fp16))
1232

1233
    if args.gradient_accumulation_steps < 1:
1234
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
1235
                            args.gradient_accumulation_steps))
1236

1237
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
1238

1239
    random.seed(args.seed)
1240
    np.random.seed(args.seed)
1241
    torch.manual_seed(args.seed)
1242
    if n_gpu > 0:
1243
        torch.cuda.manual_seed_all(args.seed)
1244

1245
    if not args.do_train:
1246
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")
1247

1248
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
1249
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
1250
    if not os.path.exists(args.output_dir):
1251
        os.makedirs(args.output_dir)
1252

1253
    #tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model, do_lower_case=args.do_lower_case)
1254
    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)
1255

1256

1257
    #train_examples = None
1258
    num_train_optimization_steps = None
1259
    if args.do_train:
1260
        print("Loading Train Dataset", args.data_dir_indomain)
1261
        #train_dataset = Dataset_noNext(args.data_dir, tokenizer, seq_len=args.max_seq_length, corpus_lines=None, on_memory=args.on_memory)
1262
        all_type_sentence, train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
1263
        num_train_optimization_steps = int(
1264
            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
1265
        if args.local_rank != -1:
1266
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
1267

1268

1269

1270
    # Prepare model
1271
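    # Note: RobertaForMaskedLMDomainTask is not part of stock Hugging Face transformers; it is
    # provided by the modified transformers package bundled with this repository (see the import
    # at the top of the file).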
    model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1272
    #model = RobertaForSequenceClassification.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
1273
    model.to(device)
1274

1275

1276

1277
    # Prepare optimizer
1278
    if args.do_train:
1279
        param_optimizer = list(model.named_parameters())
1280
        '''
1281
        for par in param_optimizer:
1282
            print(par[0])
1283
        exit()
1284
        '''
1285
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
1286
        optimizer_grouped_parameters = [
1287
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
1288
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
1289
            ]
1290
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
1291
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_train_optimization_steps*args.warmup_proportion), num_training_steps=num_train_optimization_steps)
1292

1293
        if args.fp16:
1294
            try:
1295
                from apex import amp
1296
            except ImportError:
1297
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
1298
1299

1300
            model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
1301

1302

1303
        if n_gpu > 1:
1304
            model = torch.nn.DataParallel(model)
1305

1306
        if args.local_rank != -1:
1307
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
1308

1309

1310

1311
    global_step = 0
1312
    if args.do_train:
1313
        logger.info("***** Running training *****")
1314
        logger.info("  Num examples = %d", len(train_dataset))
1315
        logger.info("  Batch size = %d", args.train_batch_size)
1316
        logger.info("  Num steps = %d", num_train_optimization_steps)
1317

1318
        if args.local_rank == -1:
1319
            train_sampler = RandomSampler(train_dataset)
1320
            #all_type_sentence_sampler = RandomSampler(all_type_sentence)
1321
        else:
1322
            #TODO: check if this works with current data generator from disk that relies on next(file)
1323
            # (it doesn't return item back by index)
1324
            train_sampler = DistributedSampler(train_dataset)
1325
            #all_type_sentence_sampler = DistributedSampler(all_type_sentence)
1326
        train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
1327
        #all_type_sentence_dataloader = DataLoader(all_type_sentence, sampler=all_type_sentence_sampler, batch_size=len(all_type_sentence_label))
1328

1329
        output_loss_file = os.path.join(args.output_dir, "loss")
1330
        loss_fout = open(output_loss_file, 'w')
1331

1332

1333
        output_loss_file_no_pseudo = os.path.join(args.output_dir, "loss_no_pseudo")
1334
        loss_fout_no_pseudo = open(output_loss_file_no_pseudo, 'w')
1335
        model.train()
1336

1337

1338
        #print(model.parameters)
1339
        #print(model.modules.module)
1340
        #print(model.modules.module.RobertaForMaskedLMDomainTask)
1341
        #print(list(model.named_parameters()))
1342
        #print([i for i,j in model.named_parameters()])
1343
        #print(model.parameters["RobertaForMaskedLMDomainTask"])
1344

1345
        ## TODO: confirm whether to use input_ids or input_ids_org here.
1346
        ###
1347
        #[10000000, 13, 768] ---> [1000000, 768] --> [1,,] --> [batch_size,,]
1348
        ###
1349
        ###
1350
        ###
1351
        #print(docs.shape)
1352
        #exit()
1353
        #docs = docs[:,0,:]
1354
        #docs = docs.unsqueeze(0)
1355
        #docs = docs.expand(batch_size, -1, -1)
1356

1357
        #################
1358
        #################
1359
        #alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
1360
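        # Weight for the pseudo-label loss: it is multiplied by alpha*epo below, so the
        # retrieved (pseudo-labeled) examples ramp up from zero toward full weight over the epochs.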
        alpha = float(1/args.num_train_epochs)
1361

1362
        #print(docs.shape)
1363
        #([1000000, 13, 768]) -> yelp
1364
        #([5474, 1, 768]) -> open domain
1365

1366
        #Test
1367
        #docs = load_GeneralDomain_docs("data/open_domain_preprocessed_roberta/")
1368
        #
1369
        '''
1370
        docs = load_GeneralDomain_docs("data/yelp/")
1371
        if docs.shape[1]!=1: #UnboundLocalError: local variable 'docs' referenced before assignment
1372
            docs = docs[:,0,:].unsqueeze(1)
1373
            print(docs.shape)
1374
        else:
1375
            print(docs.shape)
1376
        '''
1377
        ### All-label ranking (done once, at the first step)
1378
        # 1. Train --> classifier (after 1 epoch)
1379
        # 2. All labels within one batch --> the same number as batch_size
1380
        # 3. Reduce the variables kept around
1381
        # 4.
1382

1383

1384
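        # k = number of open-domain sentences retrieved per query: the top-k are reused as
        # pseudo-labeled task data, the bottom-k as out-of-domain negatives.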
        k=8
1385
        #k=16
1386
        all_type_sentence_label = list()
1387
        all_previous_sentence_label = list()
1388
        all_type_sentiment_label = list()
1389
        all_previous_sentiment_label = list()
1390
        top_k_all_type = dict()
1391
        bottom_k_all_type = dict()
1392
        for epo in trange(int(args.num_train_epochs), desc="Epoch"):
1393
            tr_loss = 0
1394
            nb_tr_examples, nb_tr_steps = 0, 0
1395
            for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
1396

1397
                #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = batch_
1398
                #print(input_ids_.shape)
1399
                #exit()
1400

1401
                #######################
1402
                ######################
1403
                ###Init 8 type sentence
1404
                ###Init 2 type sentiment
1405
                if (step == 0) and (epo == 0):
1406
                    #batch_ = tuple(t.to(device) for t in batch_)
1407
                    #all_type_sentence_ = tuple(t.to(device) for t in all_type_sentence)
1408

1409
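                    # At the very first step, score one prototype sentence per label
                    # (all_type_sentence) against the cached corpus so that cached
                    # top-k/bottom-k retrievals exist for every label from the start.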
                    input_ids_ = torch.stack([line[0] for line in all_type_sentence]).to(device)
1410
                    input_ids_org_ = torch.stack([line[1] for line in all_type_sentence]).to(device)
1411
                    input_mask_ = torch.stack([line[2] for line in all_type_sentence]).to(device)
1412
                    segment_ids_ = torch.stack([line[3] for line in all_type_sentence]).to(device)
1413
                    lm_label_ids_ = torch.stack([line[4] for line in all_type_sentence]).to(device)
1414
                    is_next_ = torch.stack([line[5] for line in all_type_sentence]).to(device)
1415
                    tail_idxs_ = torch.stack([line[6] for line in all_type_sentence]).to(device)
1416
                    sentence_label_ = torch.stack([line[7] for line in all_type_sentence]).to(device)
1417
                    sentiment_label_ = torch.stack([line[8] for line in all_type_sentence]).to(device)
1418

1419
                    #print(sentence_label_)
1420
                    #print(sentiment_label_)
1421
                    #exit()
1422

1423
                    with torch.no_grad():
1424
                        '''
1425
                        in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1426
                        # Search id from Docs and ranking via (Domain/Task)
1427
                        query_domain = in_domain_rep.float().to("cpu")
1428
                        query_domain = query_domain.unsqueeze(1)
1429
                        query_task = in_task_rep.float().to("cpu")
1430
                        query_task = query_task.unsqueeze(1)
1431
                        '''
1432

1433
                        in_domain_rep_mean, in_task_rep_mean = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep_mean")
1434
                        # Search id from Docs and ranking via (Domain/Task)
1435
                        query_domain = in_domain_rep_mean.float().to("cpu")
1436
                        query_domain = query_domain.unsqueeze(1)
1437
                        query_task = in_task_rep_mean.float().to("cpu")
1438
                        query_task = query_task.unsqueeze(1)
1439

1440
                        ######Attend to a certain layer
1441
                        '''
1442
                        results_domain = torch.matmul(query_domain,docs[-1,:,:].T)
1443
                        results_task = torch.matmul(query_task,docs[-1,:,:].T)
1444
                        '''
1445
                        ######
1446
                        ######Attend to all 13 layers
1447
                        '''
1448
                        start = time.time()
1449
                        results_domain = torch.matmul(docs, query_domain.transpose(0,1))
1450
                        domain_attention =
1451
                        results_domain = results_domain.transpose(1,2).transpose(0,1).sum(2)
1452
                        results_task = torch.matmul(docs, query_task.transpose(0,1))
1453
                        task_attention =
1454
                        results_task = results_task.transpose(1,2).transpose(0,1).sum(2)
1455
                        end = time.time()
1456
                        print("Time:", (end-start)/60)
1457
                        '''
1458

1459
                        #start = time.time()
1460
                        #docs: [batch_size, 1000000, 768]
1461
                        #query: [batch_size, 1, 768]
1462

1463
                        #docs = docs[:,0,:]
1464
                        #docs = docs.unsqueeze(0)
1465
                        #docs = docs.expand(batch_size, -1, -1)
1466

1467
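                        # Rebuild frozen CPU copies of the model's binary-classifier heads so the cached
                        # corpus can be scored under no_grad. When DataParallel replicates the model, the
                        # returned weights are gathered once per GPU, so only the first replica's rows
                        # (shape[0]/n_gpu) are kept; this slicing is a no-op when n_gpu == 1.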
                        task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1468
                        task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1469
                        task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1470
                        task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1471

1472

1473
                        domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1474
                        domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1475
                        domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1476
                        domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768, 2)
1477

1478
                        #start = time.time()
1479
                        query_domain = query_domain.expand(-1, docs.shape[0], -1)
1480
                        query_task = query_task.expand(-1, docs.shape[0], -1)  # expand the task query (was mistakenly expanded from query_domain)
1481

1482
                        #################
1483
                        #################
1484
                        #Ranking
1485

1486
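                        # Each cached sentence gets a domain score (positive-minus-negative logit of the
                        # domain classifier on its stored vector) plus a task score (binary classifier over
                        # [query_task ; doc vector]); the sum is ranked to pick top-k / bottom-k candidates.
                        # Rough shape sketch (assuming 768-dim hidden states and N cached sentences):
                        #   docs                 [N, 1, 768]
                        #   domain_binary_logit  [N, 1, 2] -> [batch, N] after (pos - neg) and expand
                        #   query_task           [batch, N, 768];  task_binary_logit -> [batch, N]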
                        LeakyReLU = torch.nn.LeakyReLU()
1487
                        #Domain logit
1488
                        domain_binary_logit = LeakyReLU(domain_binary_classifier(docs))
1489
                        domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1490
                        #domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentence_label_.shape[0], -1)
1491
                        domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
1492
                        #Task logit
1493
                        #task_binary_logit = LeakyReLU(task_binary_classifier(torch.cat([query_task, docs[:,0,:].unsqueeze(0).expand(sentence_label_.shape[0], -1, -1)], dim=2)))
1494
                        task_binary_logit = LeakyReLU(task_binary_classifier(torch.cat([query_task, docs[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2)))
1495
                        task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
1496

1497
                        #end = time.time()
1498
                        #print("Time:", (end-start)/60)
1499
                        ######
1500
                        results_all_type = domain_binary_logit + task_binary_logit
1501
                        del domain_binary_logit, task_binary_logit
1502
                        bottom_k_all_type = torch.topk(results_all_type, k, dim=1, largest=False, sorted=False)
1503
                        top_k_all_type = torch.topk(results_all_type, k, dim=1, largest=True, sorted=False)
1504
                        del results_all_type
1505
                        #all_type_sentence_label = sentence_label_.to('cpu')
1506
                        all_type_sentiment_label = sentiment_label_.to('cpu')
1507
                        #print("--")
1508
                        #print(bottom_k_all_type.values)
1509
                        #print("--")
1510
                        #exit()
1511
                        bottom_k_all_type = {"values":bottom_k_all_type.values, "indices":bottom_k_all_type.indices}
1512
                        top_k_all_type = {"values":top_k_all_type.values, "indices":top_k_all_type.indices}
1513

1514
                ######################
1515
                ######################
1516

1517

1518

1519
                ###Normal mode
1520
                batch_ = tuple(t.to(device) for t in batch_)
1521
                #input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = batch_
1522
                input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_, sentiment_label_ = batch_
1523

1524

1525
                ###
1526
                # Generate query representation
1527
                in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")
1528

1529
                in_domain_rep_mean, in_task_rep_mean = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep_mean")
1530

1531
                #if (step%10 == 0) or (sentence_label_.shape[0] != args.train_batch_size):
1532
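                # The expensive corpus re-ranking is refreshed only every 10 steps (and whenever the
                # batch is ragged, e.g. the last partial batch); otherwise the cached rankings are
                # reused, looked up by sentiment label in the else-branch below.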
                if (step%10 == 0) or (sentiment_label_.shape[0] != args.train_batch_size):
1533
                    with torch.no_grad():
1534
                        # Search id from Docs and ranking via (Domain/Task)
1535
                        query_domain = in_domain_rep.float().to("cpu")
1536
                        query_domain = query_domain.unsqueeze(1)
1537
                        query_task = in_task_rep.float().to("cpu")
1538
                        query_task = query_task.unsqueeze(1)
1539

1540
                        ######Attend to a certain layer
1541
                        '''
1542
                        results_domain = torch.matmul(query_domain,docs[-1,:,:].T)
1543
                        results_task = torch.matmul(query_task,docs[-1,:,:].T)
1544
                        '''
1545
                        ######
1546
                        ######Attend to all 13 layers
1547
                        '''
1548
                        start = time.time()
1549
                        results_domain = torch.matmul(docs, query_domain.transpose(0,1))
1550
                        domain_attention =
1551
                        results_domain = results_domain.transpose(1,2).transpose(0,1).sum(2)
1552
                        results_task = torch.matmul(docs, query_task.transpose(0,1))
1553
                        task_attention =
1554
                        results_task = results_task.transpose(1,2).transpose(0,1).sum(2)
1555
                        end = time.time()
1556
                        print("Time:", (end-start)/60)
1557
                        '''
1558

1559
                        #start = time.time()
1560
                        #docs: [batch_size, 1000000, 768]
1561
                        #query: [batch_size, 1, 768]
1562

1563
                        #docs = docs[:,0,:]
1564
                        #docs = docs.unsqueeze(0)
1565
                        #docs = docs.expand(batch_size, -1, -1)
1566

1567
                        task_binary_classifier_weight, task_binary_classifier_bias = model(func="return_task_binary_classifier")
1568
                        task_binary_classifier_weight = task_binary_classifier_weight[:int(task_binary_classifier_weight.shape[0]/n_gpu)][:]
1569
                        task_binary_classifier_bias = task_binary_classifier_bias[:int(task_binary_classifier_bias.shape[0]/n_gpu)][:]
1570
                        task_binary_classifier = return_Classifier(task_binary_classifier_weight, task_binary_classifier_bias, 768*2, 2)
1571

1572

1573
                        domain_binary_classifier_weight, domain_binary_classifier_bias = model(func="return_domain_binary_classifier")
1574
                        domain_binary_classifier_weight = domain_binary_classifier_weight[:int(domain_binary_classifier_weight.shape[0]/n_gpu)][:]
1575
                        domain_binary_classifier_bias = domain_binary_classifier_bias[:int(domain_binary_classifier_bias.shape[0]/n_gpu)][:]
1576
                        domain_binary_classifier = return_Classifier(domain_binary_classifier_weight, domain_binary_classifier_bias, 768, 2)
1577

1578
                        #start = time.time()
1579
                        query_domain = query_domain.expand(-1, docs.shape[0], -1)
1580
                        query_task = query_task.expand(-1, docs.shape[0], -1)  # expand the task query (was mistakenly expanded from query_domain)
1581

1582
                        #################
1583
                        #################
1584
                        #Ranking
1585

1586
                        LeakyReLU = torch.nn.LeakyReLU()
1587
                        #Domain logit
1588
                        domain_binary_logit = LeakyReLU(domain_binary_classifier(docs))
1589
                        domain_binary_logit = domain_binary_logit[:,:,1] - domain_binary_logit[:,:,0]
1590
                        #domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentence_label_.shape[0], -1)
1591
                        domain_binary_logit = domain_binary_logit.squeeze(1).unsqueeze(0).expand(sentiment_label_.shape[0], -1)
1592
                        #Task logit
1593
                        #task_binary_logit = LeakyReLU(task_binary_classifier(torch.cat([query_task, docs[:,0,:].unsqueeze(0).expand(sentence_label_.shape[0], -1, -1)], dim=2)))
1594
                        task_binary_logit = LeakyReLU(task_binary_classifier(torch.cat([query_task, docs[:,0,:].unsqueeze(0).expand(sentiment_label_.shape[0], -1, -1)], dim=2)))
1595
                        task_binary_logit = task_binary_logit[:,:,1] - task_binary_logit[:,:,0]
1596

1597
                        #end = time.time()
1598
                        #print("Time:", (end-start)/60)
1599
                        ######
1600
                        results = domain_binary_logit + task_binary_logit
1601
                        del domain_binary_logit, task_binary_logit
1602
                        bottom_k = torch.topk(results, k, dim=1, largest=False, sorted=False)
1603
                        bottom_k = {"values":bottom_k.values, "indices":bottom_k.indices}
1604
                        top_k = torch.topk(results, k, dim=1, largest=True, sorted=False)
1605
                        top_k = {"values":top_k.values, "indices":top_k.indices}
1606
                        del results
1607

1608
                        #all_previous_sentence_label = sentence_label_.to('cpu')
1609
                        all_previous_sentiment_label = sentiment_label_.to('cpu')
1610

1611
                        #print(bottom_k.values)
1612
                        #print(bottom_k["values"])
1613
                        #print("==")
1614
                        #print(bottom_k_all_type.values)
1615
                        #exit()
1616

1617
                        #print(torch.cat((bottom_k.values, bottom_k_all_type.values)))
1618
                        #exit()
1619
                        bottom_k_previous = {"values":torch.cat((bottom_k["values"], bottom_k_all_type["values"]),0), "indices":torch.cat((bottom_k["indices"], bottom_k_all_type["indices"]),0)}
1620
                        top_k_previous = {"values":torch.cat((top_k["values"], top_k_all_type["values"]),0), "indices":torch.cat((top_k["indices"], top_k_all_type["indices"]),0)}
1621
                        #all_previous_sentence_label = torch.cat((all_previous_sentence_label, all_type_sentence_label))
1622
                        #print("=====")
1623
                        #print(all_previous_sentiment_label.shape)
1624
                        #print(all_type_sentiment_label.shape)
1625
                        #print("=====")
1626
                        #exit()
1627
                        all_previous_sentiment_label = torch.cat((all_previous_sentiment_label, all_type_sentiment_label))
1628

1629
                else:
1630
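                    # Reuse cached retrievals: for each example, pick at random a previously ranked
                    # example with the same sentiment label and take over its top-k / bottom-k lists.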
                    #print("all_type_sentence_label",all_type_sentence_label) #fix
1631
                    #print("all_previous_sentence_label",all_previous_sentence_label) #prev
1632
                    #print("sentence_label_",sentence_label_) #present
1633

1634
                    #used_idx = torch.tensor([random.choice(((all_previous_sentence_label==int(idx_)).nonzero()).tolist())[0] for idx_ in sentence_label_])
1635
                    used_idx = torch.tensor([random.choice(((all_previous_sentiment_label==int(idx_)).nonzero()).tolist())[0] for idx_ in sentiment_label_])
1636
                    top_k = {"values":top_k_previous["values"].index_select(0,used_idx), "indices":top_k_previous["indices"].index_select(0,used_idx)}
1637

1638
                    bottom_k = {"values":bottom_k_previous["values"].index_select(0,used_idx), "indices":bottom_k_previous["indices"].index_select(0,used_idx)}
1639
                    #random.choice(((all_previous_sentence_label==id_).nonzero()).tolist()[0]) for id_ in
1640

1641

1642
                #################
1643
                #################
1644
                #Train Domain Binary Classifier
1645
                #Domain
1646
                #pos: n ; neg:k*n
1647
                #bottom_k
1648
                #Use sample!!! at first
1649
                #bottom_k = torch.topk(results, k, dim=1, largest=False, sorted=False)
1650
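                # Domain binary classifier: per the pos/neg note above, the current in-domain mean
                # representations act as the n positives, while the bottom-k (least similar, presumably
                # out-of-domain) retrieved sentences supply the k*n negatives.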
                batch = AugmentationData_Domain(bottom_k, tokenizer, args.max_seq_length)
1651
                batch = tuple(t.to(device) for t in batch)
1652
                input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs = batch
1653
                #domain_binary_loss, domain_binary_logit = model(input_ids_org=input_ids_org, masked_lm_labels=lm_label_ids, attention_mask=input_mask, func="domain_binary_classifier", in_domain_rep=in_domain_rep.to(device))
1654
                domain_binary_loss, domain_binary_logit = model(input_ids_org=input_ids_org, masked_lm_labels=lm_label_ids, attention_mask=input_mask, func="domain_binary_classifier_mean", in_domain_rep=in_domain_rep_mean.to(device))
1655

1656
                #################
1657
                #################
1658
                #Train Task Binary Classifier
1659
                #Pseudo Task --> Won't bp to PLM: only train classifier [In domain data]
1660
                #batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep)
1661
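                # Task binary classifier: positive/negative pairs are built from the current batch's
                # mean task representations only (top_k=None); per the note above, this step trains the
                # classifier head without back-propagating into the PLM.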
                batch = AugmentationData_Task_pos_and_neg(top_k=None, tokenizer=tokenizer, max_seq_length=args.max_seq_length, add_org=batch_, in_task_rep=in_task_rep_mean)
1662
                batch = tuple(t.to(device) for t in batch)
1663
                all_in_task_rep_comb, all_sentence_binary_label = batch
1664
                #task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier")
1665
                task_binary_loss, task_binary_logit = model(all_in_task_rep_comb=all_in_task_rep_comb, all_sentence_binary_label=all_sentence_binary_label, func="task_binary_classifier_mean")
1666

1667

1668
                #################
1669
                #################
1670
                #Only train Task classifier
1671
                #Task
1672
                #top_k
1673
                #top_k = torch.topk(results, k, dim=1, largest=True, sorted=False)
1674
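                # Task classifier on retrieved data: the top-k open-domain sentences are attached as
                # pseudo-labeled examples; their loss (task_loss_query) is only switched on after
                # epoch 2 and is down-weighted by alpha*epo when the total loss is formed below.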
                batch = AugmentationData_Task(top_k, tokenizer, args.max_seq_length, add_org=batch_)
1675
                batch = tuple(t.to(device) for t in batch)
1676
                #input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label = batch
1677
                input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label, sentiment_label = batch
1678
                #split into: in_dom and query_  --> different weight
1679
                #task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentence_label_, attention_mask=input_mask_, func="task_class")
1680
                task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentiment_label_, attention_mask=input_mask_, func="task_class")
1681

1682
                if epo > 2:
1683
                    #task_loss_query, class_logit_query = model(input_ids_org=input_ids_org, sentence_label=sentence_label, attention_mask=input_mask, func="task_class")
1684
                    task_loss_query, class_logit_query = model(input_ids_org=input_ids_org, sentence_label=sentiment_label, attention_mask=input_mask, func="task_class")
1685
                else:
1686
                    task_loss_query = torch.tensor(0.0, device=device)  # placeholder so the pseudo term is defined (and on the right device) in early epochs
1687

1688
                #loss = task_loss_org + (task_loss_query*alpha*epo*step)/k
1689
                #loss = domain_binary_loss + task_binary_loss + task_loss_org + (task_loss_query*alpha*epo*step)/k
1690

1691
                ##############################
1692
                ##############################
1693

1694
                if n_gpu > 1:
                    # mean() to average the per-replica losses returned by DataParallel
                    domain_binary_loss = domain_binary_loss.mean()
                    task_binary_loss = task_binary_loss.mean()
                    task_loss_org = task_loss_org.mean()
                    task_loss_query = task_loss_query.mean()

                # Pseudo-label term is ramped up linearly with the epoch index (weight alpha*epo).
                #pseudo = (task_loss_query*alpha*epo*step)/k
                pseudo = task_loss_query * alpha * epo
                loss = domain_binary_loss + task_binary_loss + task_loss_org + pseudo
1702

1703
                if args.gradient_accumulation_steps > 1:
1704
                    loss = loss / args.gradient_accumulation_steps
1705
                if args.fp16:
1706
                    #optimizer.backward(loss)
1707
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
1708
                        scaled_loss.backward()
1709
                else:
1710
                    loss.backward()
1711

1712
                ###
1713
                loss_fout.write("{}\n".format(loss.item()))
1714
                ###
1715

1716
                ###
1717
                loss_fout_no_pseudo.write("{}\n".format(loss.item()-pseudo.item()))
1718
                ###
1719

1720
                tr_loss += loss.item()
1721
                #nb_tr_examples += input_ids.size(0)
1722
                nb_tr_examples += input_ids_.size(0)
1723
                nb_tr_steps += 1
1724
                if (step + 1) % args.gradient_accumulation_steps == 0:
1725
                    if args.fp16:
1726
                        # modify learning rate with special warm up BERT uses
1727
                        # if args.fp16 is False, BertAdam is used that handles this automatically
1728
                        #lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
1729
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
1730
                    ###
1731
                    else:
1732
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
1733
                    ###
1734

1735
                    optimizer.step()
1736
                    ###
1737
                    scheduler.step()
1738
                    ###
1739
                    #optimizer.zero_grad()
1740
                    model.zero_grad()
1741
                    global_step += 1
1742

1743

1744

1745
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
1746
            output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
1747
            torch.save(model_to_save.state_dict(), output_model_file)
1748
            ####
1749
            '''
1750
            #if args.num_train_epochs/args.augment_times in [1,2,3]:
1751
            if (args.num_train_epochs/(args.augment_times/5))%5 == 0:
1752
                model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
1753
                output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
1754
                torch.save(model_to_save.state_dict(), output_model_file)
1755
            '''
1756
            ####
1757

1758
        loss_fout.close()
        loss_fout_no_pseudo.close()
1759

1760
        # Save a trained model
1761
        logger.info("** ** * Saving fine - tuned model ** ** * ")
1762
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
1763
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
1764
        if args.do_train:
1765
            torch.save(model_to_save.state_dict(), output_model_file)
1766

1767

1768

1769
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
1770
    """Truncates a sequence pair in place to the maximum length."""
1771

1772
    # This is a simple heuristic which will always truncate the longer sequence
1773
    # one token at a time. This makes more sense than truncating an equal percent
1774
    # of tokens from each, since if one sequence is very short then each token
1775
    # that's truncated likely contains more information than a longer sequence.
1776
    while True:
1777
        #total_length = len(tokens_a) + len(tokens_b)
1778
        total_length = len(tokens_a)
1779
        if total_length <= max_length:
1780
            break
1781
        else:
1782
            tokens_a.pop()
1783

1784

1785
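# Note: despite the name, this returns the number of correct predictions (a count, not a ratio);
# callers are expected to divide by the number of examples.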
def accuracy(out, labels):
1786
    outputs = np.argmax(out, axis=1)
1787
    return np.sum(outputs == labels)
1788

1789

1790
if __name__ == "__main__":
1791
    main()
1792
