CSS-LM

Форк
0
/
pretrain_roberta_including_Preprocess_DomainTask_org.py 
1176 строк · 48.5 Кб
1
# coding=utf-8
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
"""BERT finetuning runner."""
17

18
from __future__ import absolute_import, division, print_function, unicode_literals
19

20
import argparse
21
import logging
22
import os
23
import random
24
from io import open
25
import json
26
import time
27

28
import numpy as np
29
import torch
30
from torch.utils.data import DataLoader, Dataset, RandomSampler
31
from torch.utils.data.distributed import DistributedSampler
32
from tqdm import tqdm, trange
33

34
from transformers import RobertaTokenizer, RobertaForMaskedLM, RobertaForSequenceClassification
35
from transformers.modeling_roberta import RobertaForMaskedLMDomainTask
36
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
37

38
# Script-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
# Module logger (random_word uses it to warn about out-of-vocabulary tokens).
logger = logging.getLogger(__name__)
42

43
def load_GeneralDomain(dir_data_out):
    """Load the general-domain sentence embeddings and the sentence table.

    Reads ``train_CLS.pt`` (with ``torch.load``) and ``train.json`` from
    ``dir_data_out``. The directory may be given with or without a trailing
    path separator. (An earlier variant of this script read
    ``opendomain_CLS.pt`` / ``opendomain.json`` instead.)

    Args:
        dir_data_out: directory containing both files.

    Returns:
        ``(docs, data)``: the object stored in ``train_CLS.pt`` and the parsed
        content of ``train.json``.
    """
    print("===========")
    print("Load CLS.pt and train.json")
    print("-----------")
    # os.path.join works whether or not dir_data_out ends with a separator;
    # plain string concatenation silently produced e.g. "data/yelptrain.json".
    docs = torch.load(os.path.join(dir_data_out, "train_CLS.pt"))
    print("CLS.pt Done")
    print("-----------")
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(os.path.join(dir_data_out, "train.json"), encoding="utf-8") as file:
        data = json.load(file)
    print("train.json Done")
    print("===========")
    return docs, data
61

62
# Load the out-of-domain (general-domain) corpus once, at import time.
# AugmentationData_Domain / AugmentationData_Task look sentences up in the
# module-level ``data`` produced here.
# Alternative corpus directory: "data/open_domain_preprocessed_roberta/".
_GENERAL_DOMAIN_DIR = "data/yelp/"
docs, data = load_GeneralDomain(_GENERAL_DOMAIN_DIR)
67

68
def in_Domain_Task_Data_mutiple(data_dir_indomain, tokenizer, max_seq_length):
    """Build one masked-LM sample per in-domain training sentence.

    Reads ``train.json`` from ``data_dir_indomain``, a list of records with at
    least ``"sentence"`` and ``"aspect"`` keys. Each distinct aspect is mapped
    to an integer id in sorted order, and the mapping is printed. Every
    sentence is then tokenized and turned into features with
    ``convert_example_to_features``.

    Args:
        data_dir_indomain: directory containing ``train.json`` (trailing
            separator optional).
        tokenizer: tokenizer providing ``tokenize`` and whatever
            ``convert_example_to_features`` needs.
        max_seq_length: padded sequence length.

    Returns:
        A list with one 8-tuple per record, ``(input_ids, input_ids_org,
        input_mask, segment_ids, lm_label_ids, is_next, tail_idxs,
        sentence_label)``, all tensors. ``is_next`` is always 0.
    """
    with open(os.path.join(data_dir_indomain, "train.json"), encoding="utf-8") as file:
        data = json.load(file)

    num_label = sorted({line["aspect"] for line in data})
    label_map = {label: i for i, label in enumerate(num_label)}
    print("=======")
    print("label_map:")
    print(label_map)
    print("=======")

    cur_tensors_list = []
    # guid used to be the builtin ``id`` function; use the record index.
    for guid, line in enumerate(data):
        tokens_a = tokenizer.tokenize(line["sentence"])
        cur_example = InputExample(guid=guid, tokens_a=tokens_a, tokens_b=None, is_next=0)
        cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

        cur_tensors_list.append((torch.tensor(cur_features.input_ids),
                                 torch.tensor(cur_features.input_ids_org),
                                 torch.tensor(cur_features.input_mask),
                                 torch.tensor(cur_features.segment_ids),
                                 torch.tensor(cur_features.lm_label_ids),
                                 torch.tensor(0),
                                 torch.tensor(cur_features.tail_idxs),
                                 torch.tensor(label_map[line["aspect"]])))

    return cur_tensors_list
157

158

159
def in_Domain_Task_Data_binary(data_dir_indomain, tokenizer, max_seq_length):
    """Build grouped samples: each sentence plus one random sentence per other aspect.

    Reads ``train.json`` from ``data_dir_indomain``, a list of records with at
    least ``"sentence"`` and ``"aspect"`` keys. For every record, a group is
    formed from the record's own sentence (first) followed by one randomly
    chosen sentence from each other aspect, in sorted aspect order. Each
    sentence in the group is converted to features, and the group is stacked.

    Args:
        data_dir_indomain: directory containing ``train.json`` (trailing
            separator optional).
        tokenizer: tokenizer providing ``tokenize`` and whatever
            ``convert_example_to_features`` needs.
        max_seq_length: padded sequence length.

    Returns:
        A list with one 8-tuple of stacked tensors per record,
        ``(input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids,
        is_next, tail_idxs, sentence_labels)``. Each tensor has the group size
        (number of distinct aspects) as its first dimension.
    """
    with open(os.path.join(data_dir_indomain, "train.json"), encoding="utf-8") as file:
        data = json.load(file)

    # Group the sentences by aspect. The old code stored the first sentence as
    # a bare string and later ones as 1-element lists, so indexing the random
    # pick with [0] sometimes returned just the first character.
    label_sentence_dict = {}
    for line in data:
        label_sentence_dict.setdefault(line["aspect"], []).append(line["sentence"])

    num_label = sorted(label_sentence_dict)
    label_map = {label: i for i, label in enumerate(num_label)}

    all_cur_tensors = []
    for line in data:
        sentence = line["sentence"]
        label = line["aspect"]
        sentence_out = [(random.choice(label_sentence_dict[label_out]), label_out)
                        for label_out in num_label if label_out != label]
        all_sentence = [(sentence, label)] + sentence_out  # chosen sentence first

        all_input_ids = []
        all_input_ids_org = []  # was missing -> NameError on first append
        all_input_mask = []
        all_segment_ids = []
        all_lm_labels_ids = []
        all_is_next = []
        all_tail_idxs = []
        all_sentence_labels = []
        for guid, (text, text_label) in enumerate(all_sentence):
            tokens_a = tokenizer.tokenize(text)
            cur_example = InputExample(guid=guid, tokens_a=tokens_a, tokens_b=None, is_next=0)
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(torch.tensor(label_map[text_label]))

        all_cur_tensors.append((torch.stack(all_input_ids),
                                torch.stack(all_input_ids_org),
                                torch.stack(all_input_mask),
                                torch.stack(all_segment_ids),
                                torch.stack(all_lm_labels_ids),
                                torch.stack(all_is_next),
                                torch.stack(all_tail_idxs),
                                torch.stack(all_sentence_labels)))

    return all_cur_tensors
235

236

237

238
def AugmentationData_Domain(top_k, tokenizer, max_seq_length):
    """Turn retrieved general-domain sentence ids into one stacked MLM batch.

    ``top_k.indices`` is reshaped to ``shape[0] * shape[1]`` elements, which
    assumes it is 2-D (presumably a ``torch.topk`` result; confirm with the
    caller). Each id is looked up as ``data[str(id)]['sentence']`` in the
    module-level ``data`` dict, tokenized, and converted with
    ``convert_example_to_features``.

    Returns:
        7-tuple of stacked tensors ``(input_ids, input_ids_org, input_mask,
        segment_ids, lm_label_ids, is_next, tail_idxs)``. ``is_next`` is all
        zeros.
    """
    indices = top_k.indices
    n_rows, n_cols = indices.shape[0], indices.shape[1]
    sentence_ids = indices.reshape(n_rows * n_cols).tolist()

    feature_list = []
    for position, sentence_id in enumerate(sentence_ids):
        text = data[str(sentence_id)]['sentence']
        example = InputExample(guid=position, tokens_a=tokenizer.tokenize(text),
                               tokens_b=None, is_next=0)
        feature_list.append(convert_example_to_features(example, max_seq_length, tokenizer))

    def stacked(field):
        # One tensor per feature object, stacked along a new batch dimension.
        return torch.stack([torch.tensor(getattr(feat, field)) for feat in feature_list])

    return (
        stacked("input_ids"),
        stacked("input_ids_org"),
        stacked("input_mask"),
        stacked("segment_ids"),
        stacked("lm_label_ids"),
        torch.stack([torch.tensor(0) for _ in feature_list]),
        stacked("tail_idxs"),
    )
288

289

290

291
def AugmentationData_Task(top_k, tokenizer, max_seq_length, add_org=None):
    """Build a task batch from retrieved general-domain sentences plus originals.

    For each row ``id_1`` of ``top_k.indices`` (rows of sentence ids, looked up
    as ``data[str(id)]['sentence']`` in the module-level ``data`` dict):

    - each retrieved sentence is converted to features and labelled with
      ``sentence_label_[id_1]``;
    - then row ``id_1`` of the original batch (``add_org``) is appended.

    Args:
        top_k: object with an ``indices`` attribute iterable as rows of
            sentence ids (presumably a ``torch.topk`` result; confirm).
        tokenizer: tokenizer providing ``tokenize`` and whatever
            ``convert_example_to_features`` needs.
        max_seq_length: padded sequence length.
        add_org: 8-tuple ``(input_ids, input_ids_org, input_mask, segment_ids,
            lm_label_ids, is_next, tail_idxs, sentence_label)`` of the original
            batch, each indexable by row. Required despite the ``None``
            default, which is kept only for signature compatibility.

    Returns:
        8-tuple of stacked tensors in the same field order as ``add_org``.

    Raises:
        ValueError: if ``add_org`` is None.
    """
    if add_org is None:
        raise ValueError("AugmentationData_Task requires add_org (the original batch tensors).")
    (input_ids_, input_ids_org_, input_mask_, segment_ids_,
     lm_label_ids_, is_next_, tail_idxs_, sentence_label_) = add_org

    all_input_ids = []
    all_input_ids_org = []
    all_input_mask = []
    all_segment_ids = []
    all_lm_labels_ids = []
    all_is_next = []
    all_tail_idxs = []
    all_sentence_labels = []

    for id_1, sent in enumerate(top_k.indices):
        # torch.as_tensor accepts a tensor element (no copy-construct warning,
        # unlike torch.tensor(<tensor>)) as well as a plain int.
        row_label = torch.as_tensor(sentence_label_[id_1])
        for id_2, sent_id in enumerate(sent):
            t1 = data[str(int(sent_id))]['sentence']
            tokens_a = tokenizer.tokenize(t1)
            # guid used to be the builtin ``id`` function; use the position.
            cur_example = InputExample(guid=id_2, tokens_a=tokens_a, tokens_b=None, is_next=0)
            cur_features = convert_example_to_features(cur_example, max_seq_length, tokenizer)

            all_input_ids.append(torch.tensor(cur_features.input_ids))
            all_input_ids_org.append(torch.tensor(cur_features.input_ids_org))
            all_input_mask.append(torch.tensor(cur_features.input_mask))
            all_segment_ids.append(torch.tensor(cur_features.segment_ids))
            all_lm_labels_ids.append(torch.tensor(cur_features.lm_label_ids))
            all_is_next.append(torch.tensor(0))
            all_tail_idxs.append(torch.tensor(cur_features.tail_idxs))
            all_sentence_labels.append(row_label)

        # The original in-domain sample follows its retrieved sentences.
        all_input_ids.append(input_ids_[id_1])
        all_input_ids_org.append(input_ids_org_[id_1])
        all_input_mask.append(input_mask_[id_1])
        all_segment_ids.append(segment_ids_[id_1])
        all_lm_labels_ids.append(lm_label_ids_[id_1])
        all_is_next.append(is_next_[id_1])
        all_tail_idxs.append(tail_idxs_[id_1])
        all_sentence_labels.append(sentence_label_[id_1])

    return (torch.stack(all_input_ids),
            torch.stack(all_input_ids_org),
            torch.stack(all_input_mask),
            torch.stack(all_segment_ids),
            torch.stack(all_lm_labels_ids),
            torch.stack(all_is_next),
            torch.stack(all_tail_idxs),
            torch.stack(all_sentence_labels))
360

361

362

363
class Dataset_noNext(Dataset):
    """Masked-LM dataset with one line per sample and no next-sentence task.

    The corpus file holds one sentence per line, with blank lines separating
    documents. With ``on_memory=True`` the whole file is loaded up front.
    Otherwise lines are streamed sequentially from an open file handle.

    Each item is a 7-tuple of tensors ``(input_ids, input_ids_org,
    input_mask, segment_ids, lm_label_ids, is_next, tail_idxs)`` built by
    ``convert_example_to_features``.
    """

    def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
        self.vocab_size = tokenizer.vocab_size
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.on_memory = on_memory
        self.corpus_lines = corpus_lines  # number of non-empty lines in input corpus
        self.corpus_path = corpus_path
        self.encoding = encoding
        self.current_doc = 0  # to avoid random sentence from same doc

        # for loading samples directly from file
        self.sample_counter = 0  # used to keep track of full epochs on file
        self.line_buffer = None  # line read ahead in lazy mode, used as the next t1

        # for loading samples in memory
        self.current_random_doc = 0
        self.num_docs = 0
        self.sample_to_doc = []  # map sample index to doc and line

        if on_memory:
            self.all_docs = []
            doc = []
            self.corpus_lines = 0
            with open(corpus_path, "r", encoding=encoding) as f:
                for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                    line = line.strip()
                    if line == "":
                        # Close the current document. Leading blank lines and
                        # runs of blank lines are skipped, so no empty doc is
                        # stored and no sample of another doc gets popped.
                        if doc:
                            self.all_docs.append(doc)
                            doc = []
                            # The doc's last line is not kept as a sample
                            # (inherited from the next-sentence setup).
                            self.sample_to_doc.pop()
                    else:
                        self.sample_to_doc.append({"doc_id": len(self.all_docs),
                                                   "line": len(doc)})
                        doc.append(line)
                        self.corpus_lines += 1

            # Final document when the file does not end with a blank line.
            # (Comparing against all_docs[-1] crashed when there was no blank
            # line at all, and added an empty doc after a trailing blank line.)
            if doc:
                self.all_docs.append(doc)
                self.sample_to_doc.pop()

            self.num_docs = len(self.all_docs)

        # load samples later lazily from disk
        else:
            if self.corpus_lines is None:
                with open(corpus_path, "r", encoding=encoding) as f:
                    self.corpus_lines = 0
                    line = ""  # keeps the final check valid for an empty file
                    for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
                        if line.strip() == "":
                            self.num_docs += 1
                        else:
                            self.corpus_lines += 1

                    # if doc does not end with empty line
                    if line.strip() != "":
                        self.num_docs += 1

            self.file = open(corpus_path, "r", encoding=encoding)
            self.random_file = open(corpus_path, "r", encoding=encoding)

    def __len__(self):
        # The last line of each doc is not used as a sample; additionally, we start counting at 0.
        return self.corpus_lines - self.num_docs - 1

    def __getitem__(self, item):
        cur_id = self.sample_counter
        self.sample_counter += 1
        if not self.on_memory:
            # after one epoch we start again from beginning of file
            if cur_id != 0 and (cur_id % len(self) == 0):
                self.file.close()
                self.file = open(self.corpus_path, "r", encoding=self.encoding)

        t1, is_next_label = self.random_sent(item)
        if is_next_label is None:
            is_next_label = 0

        # Use this dataset's own tokenizer; the old code referenced a global
        # name ``tokenizer`` that this class does not own.
        tokens_a = self.tokenizer.tokenize(t1)

        cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=None, is_next=is_next_label)

        # transform sample to features
        cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)

        cur_tensors = (torch.tensor(cur_features.input_ids),
                       torch.tensor(cur_features.input_ids_org),
                       torch.tensor(cur_features.input_mask),
                       torch.tensor(cur_features.segment_ids),
                       torch.tensor(cur_features.lm_label_ids),
                       torch.tensor(cur_features.is_next),
                       torch.tensor(cur_features.tail_idxs))

        return cur_tensors

    def random_sent(self, index):
        """
        Get the sentence for sample ``index``; no next-sentence pairing is done.
        :param index: int, index of sample.
        :return: (str, None), the sentence and a None next-sentence label
        """
        t1, _ = self.get_corpus_line(index)
        return t1, None

    def get_corpus_line(self, item):
        """
        Get one sentence from the corpus.
        In memory mode ``item`` selects the sample; in lazy mode the next
        non-empty line of the file is returned and ``item`` is only range-checked.
        :param item: int, index of sample.
        :return: (str, str), the sentence and an always-empty second string
        """
        t1 = ""
        t2 = ""
        assert item < self.corpus_lines
        if self.on_memory:
            sample = self.sample_to_doc[item]
            t1 = self.all_docs[sample["doc_id"]][sample["line"]]
            # used later to avoid random nextSentence from same doc
            self.current_doc = sample["doc_id"]
            return t1, t2
        else:
            if self.line_buffer is None:
                # read first non-empty line of file
                while t1 == "":
                    t1 = next(self.file).strip()
            else:
                # use the line read ahead in the previous call as new t1
                t1 = self.line_buffer
                # skip empty rows that are used for separating documents and keep track of current doc id
                while t1 == "":
                    t1 = next(self.file).strip()
                    self.current_doc = self.current_doc + 1
            self.line_buffer = next(self.file).strip()

        assert t1 != ""
        return t1, t2

    def get_random_line(self):
        """
        Get random line from another document for nextSentence task.
        :return: str, content of one line
        """
        # Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
        # corpora. However, just to be careful, we try to make sure that
        # the random document is not the same as the document we're processing.
        for _ in range(10):
            if self.on_memory:
                rand_doc_idx = random.randint(0, len(self.all_docs) - 1)
                rand_doc = self.all_docs[rand_doc_idx]
                line = rand_doc[random.randrange(len(rand_doc))]
                # Record the sampled doc so the same-doc check below is meaningful.
                self.current_random_doc = rand_doc_idx
            else:
                rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
                # pick random line
                for _ in range(rand_index):
                    line = self.get_next_line()
            # check if our picked random line is really from another doc like we want it to be
            if self.current_random_doc != self.current_doc:
                break
        return line

    def get_next_line(self):
        """ Gets next line of random_file and starts over when reaching end of file"""
        try:
            line = next(self.random_file).strip()
            # keep track of which document we are currently looking at to later avoid having the same doc as t1
            if line == "":
                self.current_random_doc = self.current_random_doc + 1
                line = next(self.random_file).strip()
        except StopIteration:
            self.random_file.close()
            self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
            line = next(self.random_file).strip()
        return line
557

558

559
class InputExample(object):
    """One raw (pre-feature) example for masked language modelling.

    Attributes:
        guid: identifier for the example.
        tokens_a: list of str tokens of the first sequence (callers in this
            file pass the output of ``tokenizer.tokenize``).
        tokens_b: optional token list of a second sequence; None for
            single-sequence examples.
        is_next: optional next-sentence label.
        lm_labels: optional masked-LM labels.
    """

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        self.guid, self.tokens_a, self.tokens_b = guid, tokens_a, tokens_b
        self.is_next = is_next      # next-sentence label
        self.lm_labels = lm_labels  # masked-word labels
578

579

580
class InputFeatures(object):
    """Model-ready features of one example, as produced by convert_example_to_features.

    Every constructor argument is stored unchanged as an attribute of the same
    name: ``input_ids`` (masked ids), ``input_ids_org`` (unmasked ids),
    ``input_mask``, ``segment_ids``, ``is_next``, ``lm_label_ids`` and
    ``tail_idxs``.
    """

    def __init__(self, input_ids, input_ids_org, input_mask, segment_ids, is_next, lm_label_ids, tail_idxs):
        attributes = (
            ("input_ids", input_ids),
            ("input_ids_org", input_ids_org),
            ("input_mask", input_mask),
            ("segment_ids", segment_ids),
            ("is_next", is_next),
            ("lm_label_ids", lm_label_ids),
            ("tail_idxs", tail_idxs),
        )
        for attr_name, attr_value in attributes:
            setattr(self, attr_name, attr_value)
591

592

593
def random_word(tokens, tokenizer):
    """Apply BERT-style random masking to ``tokens`` (in place).

    Each position is selected with probability 0.15. A selected token is
    replaced by ``"<mask>"`` 80% of the time, by a uniformly random vocabulary
    token 10% of the time, and kept unchanged the remaining 10%. Its label is
    the id of the ORIGINAL token. Unselected positions get label -1, meant to
    be ignored by the loss.

    :param tokens: list of str, tokenized sentence; modified in place.
    :param tokenizer: object providing ``vocab_size``, ``convert_ids_to_tokens``
        and ``convert_tokens_to_ids``.
    :return: (list of str, list of int), the same (mutated) token list and the
        per-position LM labels.
    :raises ValueError: if the tokenizer maps a selected token to ``None``.
    """
    output_label = []

    for i, token in enumerate(tokens):
        prob = random.random()
        if prob >= 0.15:
            # Not selected: no prediction at this position.
            output_label.append(-1)
            continue

        prob /= 0.15
        if prob < 0.8:
            tokens[i] = "<mask>"
        elif prob < 0.9:
            # randrange excludes vocab_size itself; randint(0, vocab_size) is
            # inclusive and could yield an out-of-range id.
            candidate_id = random.randrange(tokenizer.vocab_size)
            tokens[i] = tokenizer.convert_ids_to_tokens(candidate_id)
        # else: keep the current token unchanged.

        # The label is the id of the original token (``token`` was bound
        # before any replacement above).
        try:
            label_id = tokenizer.convert_tokens_to_ids(token)
        except KeyError:
            # For unknown words (should not occur with BPE vocab)
            label_id = tokenizer.convert_tokens_to_ids("<unk>")
            logger.warning("Cannot find token '{}' in vocab. Using <unk> insetad".format(token))
        else:
            if label_id is None:
                # Previously print + exit(), which ended the process with a
                # success status; raise so the failure is visible.
                raise ValueError("Token {!r} has no id in the tokenizer vocabulary".format(token))
        output_label.append(label_id)

    return tokens, output_label
653

654

655
def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert an InputExample into padded, randomly masked RoBERTa features.

    Steps: truncate ``example.tokens_a`` (in place, via ``_truncate_seq_pair``)
    to leave room for ``<s>``/``</s>``, keep an unmasked copy, mask with
    ``random_word``, wrap both versions as ``<s> ... </s>``, convert them to
    ids and pad every list to ``max_seq_length``.

    :param example: InputExample whose ``tokens_a`` is a list of str tokens;
        ``tokens_b`` is only passed through to the truncation helper.
    :param max_seq_length: int, length every output list is padded to.
    :param tokenizer: tokenizer providing ``encode`` and
        ``convert_tokens_to_ids`` (plus what ``random_word`` needs).
    :return: InputFeatures with ``input_ids`` (masked), ``input_ids_org``
        (unmasked), ``input_mask`` (1 real / 0 padding), ``segment_ids``
        (all 0), ``lm_label_ids`` (-1 where nothing is predicted),
        ``is_next`` copied from the example, and ``tail_idxs`` (unpadded id
        count + 1).
    :raises AssertionError: if an output list does not end up exactly
        ``max_seq_length`` long.
    """
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Reserve two positions for <s> and </s>.
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 2)

    # random_word mutates its argument, so snapshot the original tokens first.
    tokens_a_org = tokens_a.copy()
    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    # No LM prediction on the <s> / </s> positions.
    lm_label_ids = [-1] + t1_label + [-1]

    def _escape_eos(token):
        # A literal "</s>" inside the sentence would look like an
        # end-of-sequence marker; replace it by "s". Applied to the masked and
        # the original sequence independently. The old code decided on the
        # masked token only, so a random "</s>" replacement also erased the
        # original token, and an original "</s>" under <mask> stayed raw.
        return "s" if token == "</s>" else token

    tokens = ["<s>"] + [_escape_eos(t) for t in tokens_a] + ["</s>"]
    tokens_org = ["<s>"] + [_escape_eos(t) for t in tokens_a_org] + ["</s>"]
    segment_ids = [0] * len(tokens)

    input_ids = tokenizer.encode(tokens, add_special_tokens=False)
    input_ids_org = tokenizer.encode(tokens_org, add_special_tokens=False)
    tail_idxs = len(input_ids) + 1

    # Guard against None ids coming back from the tokenizer (mapped to 0).
    input_ids = [w if w is not None else 0 for w in input_ids]
    input_ids_org = [w if w is not None else 0 for w in input_ids_org]

    # The mask has 1 for real tokens and 0 for padding tokens.
    input_mask = [1] * len(input_ids)

    # Pad every list by however much input_ids is short of max_seq_length.
    pad_id = tokenizer.convert_tokens_to_ids("<pad>")
    n_pad = max(0, max_seq_length - len(input_ids))
    input_ids += [pad_id] * n_pad
    input_ids_org += [pad_id] * n_pad
    input_mask += [0] * n_pad
    segment_ids += [0] * n_pad
    lm_label_ids += [-1] * n_pad

    assert len(input_ids) == max_seq_length
    assert len(input_ids_org) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    return InputFeatures(input_ids=input_ids,
                         input_ids_org=input_ids_org,
                         input_mask=input_mask,
                         segment_ids=segment_ids,
                         lm_label_ids=lm_label_ids,
                         is_next=example.is_next,
                         tail_idxs=tail_idxs)
780

781

782
def build_arg_parser():
    """Build the command-line parser used by ``main``.

    Returns:
        argparse.ArgumentParser: parser defining every option of this script.
        ``--data_dir_outdomain``, ``--augment_times``, ``--on_memory``,
        ``--do_lower_case``, ``--loss_scale`` and ``--task`` are parsed but
        not read by ``main``.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir_indomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.(In Domain)")
    parser.add_argument("--data_dir_outdomain",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.(Out Domain)")
    parser.add_argument("--pretrain_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")
    parser.add_argument("--augment_times",
                        default=None,
                        type=int,
                        required=True,
                        help="Default batch_size/augment_times to save model")
    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--on_memory",
                        action='store_true',
                        help="Whether to load train samples into memory or use disk")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumualte before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type = float, default = 0,
                        help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    ####
    parser.add_argument("--num_labels_task",
                        default=None, type=int,
                        required=True,
                        help="num_labels_task")
    # Applied to every parameter except biases and LayerNorm weights.
    parser.add_argument("--weight_decay",
                        default=0.01,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument('--fp16_opt_level',
                        type=str,
                        default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--task",
                        default=None,
                        type=int,
                        required=True,
                        help="Choose Task")
    return parser


def main():
    """Train ``RobertaForMaskedLMDomainTask`` on in-domain task data plus retrieved data.

    For each in-domain batch the model produces domain and task query
    representations. These are scored (dot product) against the module-level
    ``docs`` tensor, and the top-``k`` hits are turned into an extra batch by
    ``AugmentationData_Task``. The loss is the in-domain task loss plus the
    retrieved-batch task loss, scaled by a weight that grows with ``epo`` and
    ``step``.

    Side effects: writes one loss value per step to ``<output_dir>/loss``, a
    checkpoint ``pytorch_model.bin_<global_step>`` after every epoch, and
    ``pytorch_model.bin`` at the end.

    Raises:
        ValueError: on invalid ``--gradient_accumulation_steps``, when
            ``--do_train`` is not set, or when ``--output_dir`` is non-empty.
        ImportError: when ``--fp16`` is set but apex is not installed.
    """
    args = build_arg_parser().parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # Count only GPUs that will actually be used: with --no_cuda this must
        # be 0, otherwise a CPU model would be wrapped in DataParallel below.
        n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    # Per-forward batch size; gradients are accumulated up to the requested total.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = RobertaTokenizer.from_pretrained(args.pretrain_model)

    print("Loading Train Dataset", args.data_dir_indomain)
    train_dataset = in_Domain_Task_Data_mutiple(args.data_dir_indomain, tokenizer, args.max_seq_length)
    steps_per_epoch = int(len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps)
    # The scheduler expects an integer number of optimizer steps.
    num_train_optimization_steps = int(steps_per_epoch * args.num_train_epochs)
    if args.local_rank != -1:
        num_train_optimization_steps //= torch.distributed.get_world_size()

    # Prepare model
    model = RobertaForMaskedLMDomainTask.from_pretrained(args.pretrain_model, output_hidden_states=True, return_dict=True, num_labels=args.num_labels_task)
    model.to(device)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(num_train_optimization_steps * args.warmup_proportion),
        num_training_steps=num_train_optimization_steps)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") from err
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)

    global_step = 0
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_optimization_steps)

    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    model.train()

    # TODO: confirm whether input_ids or input_ids_org should be fed to the model.
    alpha = float(1/(args.num_train_epochs*len(train_dataloader)))
    k = 8  # number of documents retrieved per query row (torch.topk along dim=1)
    # NOTE(review): `docs` is not defined in this function; presumably a
    # module-level tensor of document representations indexed as
    # (layer, doc, hidden) and loaded elsewhere in this file — confirm.
    # Only its last layer is used for retrieval.
    doc_matrix = docs[-1, :, :].T

    output_loss_file = os.path.join(args.output_dir, "loss")
    with open(output_loss_file, 'w') as loss_fout:
        for epo in trange(int(args.num_train_epochs), desc="Epoch"):
            if isinstance(train_sampler, DistributedSampler):
                # Without this every epoch sees the same shuffled order.
                train_sampler.set_epoch(epo)
            for step, batch_ in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch_ = tuple(t.to(device) for t in batch_)
                input_ids_, input_ids_org_, input_mask_, segment_ids_, lm_label_ids_, is_next_, tail_idxs_, sentence_label_ = batch_

                # Query representations for the in-domain batch.
                in_domain_rep, in_task_rep = model(input_ids_org=input_ids_org_, tail_idxs=tail_idxs_, attention_mask=input_mask_, func="in_domain_task_rep")

                # Score every document against both queries (on CPU).
                query_domain = in_domain_rep.float().to("cpu")
                query_task = in_task_rep.float().to("cpu")
                results_domain = torch.matmul(query_domain, doc_matrix)
                results_task = torch.matmul(query_task, doc_matrix)
                results = results_domain + results_task

                # Build the augmented batch from the k best-scoring documents.
                top_k = torch.topk(results, k, dim=1, largest=True, sorted=True)
                batch_ = tuple(t.to("cpu") for t in batch_)
                batch = AugmentationData_Task(top_k, tokenizer, args.max_seq_length, add_org=batch_)
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_ids_org, input_mask, segment_ids, lm_label_ids, is_next, tail_idxs, sentence_label = batch

                # Task loss on the original batch and on the retrieved batch.
                task_loss_org, class_logit_org = model(input_ids_org=input_ids_org_, sentence_label=sentence_label_, attention_mask=input_mask_, func="task_class")
                task_loss_query, class_logit_query = model(input_ids_org=input_ids_org, sentence_label=sentence_label, attention_mask=input_mask, func="task_class")

                # NOTE(review): epo*step is a product, so the retrieved-data
                # weight is 0 for the whole first epoch and at step 0 of every
                # epoch. If a linear ramp over global progress was intended,
                # (epo*len(train_dataloader) + step) is the usual form — confirm.
                loss = task_loss_org + (task_loss_query*alpha*epo*step)/k

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss_fout.write("{}\n".format(loss.item()))

                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

            # Per-epoch checkpoint, tagged with the number of optimizer steps so far.
            model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
            torch.save(model_to_save.state_dict(), output_model_file)

    # Save a trained model
    logger.info("** ** * Saving fine - tuned model ** ** * ")
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncate ``tokens_a`` in place so it holds at most ``max_length`` tokens.

    Despite the name, only ``tokens_a`` is shortened; the pair-wise heuristic
    (trimming whichever sequence is longer) is not applied. ``tokens_b`` is
    accepted for signature compatibility and is never read or modified.

    Args:
        tokens_a: List of tokens; trailing tokens are removed in place.
        tokens_b: Ignored.
        max_length: Maximum number of tokens to keep. Negative values are
            treated as 0 (the list is emptied).

    Returns:
        None. The truncation happens in place.
    """
    # Slice deletion drops the tail in one operation instead of popping
    # one token at a time.
    del tokens_a[max(max_length, 0):]


def accuracy(out, labels):
    """Count the rows of ``out`` whose arg-max equals the matching label.

    Args:
        out: 2-D array-like of scores, one row per example.
        labels: 1-D array-like of integer labels, one per row of ``out``.

    Returns:
        The number of correct predictions (a count, not a ratio), as the
        NumPy integer produced by ``np.sum``.
    """
    predictions = np.argmax(out, axis=1)
    return np.sum(predictions == labels)


# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.