2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
9
# http://www.apache.org/licenses/LICENSE-2.0
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
"""BERT finetuning runner."""
18
from __future__ import absolute_import, division, print_function, unicode_literals
28
from torch.utils.data import DataLoader, Dataset, RandomSampler
29
from torch.utils.data.distributed import DistributedSampler
30
from tqdm import tqdm, trange
32
from pytorch_pretrained_bert.modeling import BertForPreTraining
33
from pytorch_pretrained_bert.tokenization import BertTokenizer
34
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
import logging

# Module-wide log format. INFO level so that the progress messages emitted via
# logger.info (dataset loading, training summary, example dumps) are visible.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S',
                    level=logging.INFO)
logger = logging.getLogger(__name__)
42
# Torch Dataset used for BERT LM fine-tuning. Each item pairs two sentences from the
# corpus, tokenizes them, wraps them in an InputExample and converts that with
# convert_example_to_features into a tuple of tensors
# (input_ids, input_mask, segment_ids, lm_label_ids, is_next).
# Sentences come either from an in-memory copy of the corpus (on_memory=True) or
# are streamed from corpus_path through the self.file / self.random_file handles.
# NOTE(review): this span is a lossy extraction: indentation is gone, stray
# line-number lines are interleaved, and several statements appear to be missing
# (e.g. the initialisation of self.all_docs and doc, the on_memory if/else
# headers, the `def __len__` header, and a `return cur_tensors` statement).
# It does not parse as shown; restore it from the full source before use.
class BERTDataset(Dataset):
43
def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
44
self.vocab = tokenizer.vocab
45
self.tokenizer = tokenizer
46
self.seq_len = seq_len
47
self.on_memory = on_memory
48
self.corpus_lines = corpus_lines # number of non-empty lines in input corpus
49
self.corpus_path = corpus_path
50
self.encoding = encoding
51
self.current_doc = 0 # to avoid random sentence from same doc
53
# for loading samples directly from file
54
self.sample_counter = 0 # used to keep track of full epochs on file
55
self.line_buffer = None # keep second sentence of a pair in memory and use as first sentence in next pair
57
# for loading samples in memory
58
self.current_random_doc = 0
60
self.sample_to_doc = [] # map sample index to doc and line
62
# load samples into memory
# Empty lines in the corpus separate documents; each stored sample records
# the doc_id and line of its first sentence (read back in get_corpus_line).
67
with open(corpus_path, "r", encoding=encoding) as f:
68
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
71
self.all_docs.append(doc)
73
#remove last added sample because there won't be a subsequent line anymore in the doc
74
self.sample_to_doc.pop()
77
sample = {"doc_id": len(self.all_docs),
79
self.sample_to_doc.append(sample)
81
self.corpus_lines = self.corpus_lines + 1
83
# if last row in file is not empty
84
if self.all_docs[-1] != doc:
85
self.all_docs.append(doc)
86
self.sample_to_doc.pop()
88
self.num_docs = len(self.all_docs)
90
# load samples later lazily from disk
92
# Count corpus lines in one pass only when the caller did not supply corpus_lines.
if self.corpus_lines is None:
93
with open(corpus_path, "r", encoding=encoding) as f:
95
for line in tqdm(f, desc="Loading Dataset", total=corpus_lines):
96
if line.strip() == "":
99
self.corpus_lines += 1
101
# if doc does not end with empty line
102
if line.strip() != "":
105
# Two independent handles on the corpus: self.file yields the sequential
# sentence pairs, self.random_file supplies candidates for random next sentences.
self.file = open(corpus_path, "r", encoding=encoding)
106
self.random_file = open(corpus_path, "r", encoding=encoding)
109
# last line of doc won't be used, because there's no "nextSentence". Additionally, we start counting at 0.
110
return self.corpus_lines - self.num_docs - 1
112
def __getitem__(self, item):
113
cur_id = self.sample_counter
114
self.sample_counter += 1
115
if not self.on_memory:
116
# after one epoch we start again from beginning of file
117
if cur_id != 0 and (cur_id % len(self) == 0):
119
# NOTE(review): no close() of the previous self.file handle is visible
# before this reopen; confirm the full source closes it.
self.file = open(self.corpus_path, "r", encoding=self.encoding)
121
t1, t2, is_next_label = self.random_sent(item)
124
tokens_a = self.tokenizer.tokenize(t1)
125
tokens_b = self.tokenizer.tokenize(t2)
127
# combine to one sample
128
cur_example = InputExample(guid=cur_id, tokens_a=tokens_a, tokens_b=tokens_b, is_next=is_next_label)
130
# transform sample to features
131
cur_features = convert_example_to_features(cur_example, self.seq_len, self.tokenizer)
133
cur_tensors = (torch.tensor(cur_features.input_ids),
134
torch.tensor(cur_features.input_mask),
135
torch.tensor(cur_features.segment_ids),
136
torch.tensor(cur_features.lm_label_ids),
137
torch.tensor(cur_features.is_next))
141
def random_sent(self, index):
143
Get one sample from corpus consisting of two sentences. With prob. 50% these are two subsequent sentences
144
from one doc. With 50% the second sentence will be a random one from another doc.
145
:param index: int, index of sample.
146
:return: (str, str, int), sentence 1, sentence 2, isNextSentence Label
148
t1, t2 = self.get_corpus_line(index)
149
# NOTE(review): the lines that set the next-sentence label for each branch and
# return (t1, t2, label) are not visible in this span.
if random.random() > 0.5:
152
t2 = self.get_random_line()
159
def get_corpus_line(self, item):
161
Get one sample from corpus consisting of a pair of two subsequent lines from the same doc.
162
:param item: int, index of sample.
163
:return: (str, str), two subsequent sentences from corpus
167
assert item < self.corpus_lines
169
# In-memory lookup: sample_to_doc maps the sample index to (doc_id, line).
sample = self.sample_to_doc[item]
170
t1 = self.all_docs[sample["doc_id"]][sample["line"]]
171
t2 = self.all_docs[sample["doc_id"]][sample["line"]+1]
172
# used later to avoid random nextSentence from same doc
173
self.current_doc = sample["doc_id"]
176
# The lines below read pairs sequentially from self.file and do not use `item`
# (presumably the on_memory=False branch; its header is not visible here).
if self.line_buffer is None:
177
# read first non-empty line of file
179
t1 = next(self.file).strip()
180
t2 = next(self.file).strip()
182
# use t2 from previous iteration as new t1
183
t1 = self.line_buffer
184
t2 = next(self.file).strip()
185
# skip empty rows that are used for separating documents and keep track of current doc id
186
while t2 == "" or t1 == "":
187
t1 = next(self.file).strip()
188
t2 = next(self.file).strip()
189
self.current_doc = self.current_doc+1
190
self.line_buffer = t2
196
def get_random_line(self):
198
Get random line from another document for nextSentence task.
199
:return: str, content of one line
201
# Similar to original tf repo: This outer loop should rarely go for more than one iteration for large
202
# corpora. However, just to be careful, we try to make sure that
203
# the random document is not the same as the document we're processing.
206
# In-memory: pick a uniformly random doc, then a uniformly random line in it.
rand_doc_idx = random.randint(0, len(self.all_docs)-1)
207
rand_doc = self.all_docs[rand_doc_idx]
208
line = rand_doc[random.randrange(len(rand_doc))]
210
# From disk: advance random_file by a random 1..min(corpus_lines, 1000) lines.
rand_index = random.randint(1, self.corpus_lines if self.corpus_lines < 1000 else 1000)
212
for _ in range(rand_index):
213
line = self.get_next_line()
214
#check if our picked random line is really from another doc like we want it to be
215
if self.current_random_doc != self.current_doc:
219
def get_next_line(self):
220
""" Gets next line of random_file and starts over when reaching end of file"""
222
line = next(self.random_file).strip()
223
#keep track of which document we are currently looking at to later avoid having the same doc as t1
225
self.current_random_doc = self.current_random_doc + 1
226
line = next(self.random_file).strip()
227
# End of file: reopen random_file from the start and read its first line.
except StopIteration:
228
self.random_file.close()
229
self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
230
line = next(self.random_file).strip()
class InputExample(object):
    """A single training/test example for the language model."""

    def __init__(self, guid, tokens_a, tokens_b=None, is_next=None, lm_labels=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example (BERTDataset passes its running sample counter).
            tokens_a: list of str. The tokenized first sequence. For single
                sequence tasks, only this sequence must be specified.
            tokens_b: (Optional) list of str. The tokenized second sequence.
                Only must be specified for sequence pair tasks.
            is_next: (Optional) next-sentence label of the pair.
            lm_labels: (Optional) masked-word labels for the language model.
        """
        # guid must be stored: convert_example_to_features reads example.guid.
        self.guid = guid
        self.tokens_a = tokens_a
        self.tokens_b = tokens_b
        self.is_next = is_next  # nextSentence
        self.lm_labels = lm_labels  # masked words for language model
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, is_next, lm_label_ids):
        """Store the model inputs and labels for one training sample.

        :param input_ids: list of int, token ids padded to the sequence length.
        :param input_mask: list of int, 1 for real tokens and 0 for padding.
        :param segment_ids: list of int, 0 for the first sentence, 1 for the second.
        :param is_next: next-sentence label of the pair.
        :param lm_label_ids: list of int, masked-LM target ids (-1 where not predicted).
        """
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.is_next = is_next
        self.lm_label_ids = lm_label_ids
def random_word(tokens, tokenizer):
    """
    Masking some random tokens for Language Model task with probabilities as in the original BERT paper.
    :param tokens: list of str, tokenized sentence. Modified in place.
    :param tokenizer: Tokenizer, object used for tokenization (we need its vocab here)
    :return: (list of str, list of int), masked tokens and related labels for LM prediction
    """
    output_label = []
    # Built lazily and at most once per call: only needed when a token is swapped
    # for a random vocab entry (same order/length as the vocab, so the same
    # random.choice draw selects the same token as indexing vocab.items()).
    vocab_tokens = None

    for i, token in enumerate(tokens):
        prob = random.random()
        # mask token with 15% probability
        if prob < 0.15:
            # rescale into [0, 1) so the 80/10/10 split below applies to the selected tokens
            prob /= 0.15

            # 80% randomly change token to mask token
            if prob < 0.8:
                tokens[i] = "[MASK]"

            # 10% randomly change token to random token
            elif prob < 0.9:
                if vocab_tokens is None:
                    vocab_tokens = list(tokenizer.vocab)
                tokens[i] = random.choice(vocab_tokens)

            # -> rest 10% randomly keep current token

            # append current token to output (we will predict these later)
            try:
                output_label.append(tokenizer.vocab[token])
            except KeyError:
                # For unknown words (should not occur with BPE vocab)
                output_label.append(tokenizer.vocab["[UNK]"])
                logger.warning("Cannot find token '{}' in vocab. Using [UNK] insetad".format(token))
        else:
            # no masking token (will be ignored by loss function later)
            output_label.append(-1)

    return tokens, output_label
def convert_example_to_features(example, max_seq_length, tokenizer):
    """
    Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample with
    IDs, LM labels, input_mask, CLS and SEP tokens etc.
    :param example: InputExample, containing the tokenized sentence pair and is_next label.
        example.tokens_a / example.tokens_b are truncated and masked in place.
    :param max_seq_length: int, maximum length of sequence; outputs are padded to exactly this length.
    :param tokenizer: Tokenizer, must provide `vocab` and `convert_tokens_to_ids`.
    :return: InputFeatures, containing all inputs and labels of one sample as IDs (as used for model training)
    """
    tokens_a = example.tokens_a
    tokens_b = example.tokens_b
    # Modifies `tokens_a` and `tokens_b` in place so that the total
    # length is less than the specified length.
    # Account for [CLS], [SEP], [SEP] with "- 3"
    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)

    tokens_a, t1_label = random_word(tokens_a, tokenizer)
    tokens_b, t2_label = random_word(tokens_b, tokenizer)
    # concatenate lm labels and account for CLS, SEP, SEP
    lm_label_ids = ([-1] + t1_label + [-1] + t2_label + [-1])

    # The convention in BERT is:
    # (a) For sequence pairs:
    #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
    #  type_ids: 0   0  0    0    0     0       0 0    1  1  1  1   1 1
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0   0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    assert len(tokens_b) > 0
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # [CLS] + tokens_a + [SEP] form segment 0; tokens_b + [SEP] form segment 1.
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        lm_label_ids.append(-1)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    assert len(lm_label_ids) == max_seq_length

    # Only dump the first few examples: BERTDataset passes its running sample
    # counter as guid, so logging unconditionally would emit several INFO lines
    # for every training sample.
    if isinstance(example.guid, int) and example.guid < 5:
        logger.info("*** Example ***")
        logger.info("guid: %s" % (example.guid))
        logger.info("tokens: %s" % " ".join(
                [str(x) for x in tokens]))
        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
        logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
        logger.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
        logger.info("LM label: %s " % (lm_label_ids))
        logger.info("Is next sentence label: %s " % (example.is_next))

    features = InputFeatures(input_ids=input_ids,
                             input_mask=input_mask,
                             segment_ids=segment_ids,
                             lm_label_ids=lm_label_ids,
                             is_next=example.is_next)
    return features
400
# Training entry point (the enclosing function header is not visible in this span):
# parse CLI arguments, choose the device / distributed setup, build BERTDataset and
# BertForPreTraining, train on masked-LM + next-sentence inputs for
# num_train_epochs, then save the state_dict to <output_dir>/pytorch_model.bin.
# NOTE(review): this span is a lossy extraction: indentation is gone, stray
# line-number lines are interleaved, and many lines appear to be missing (e.g. the
# type/default/action kwargs of most add_argument calls, the else/try branch
# headers, and the initialisation of tr_loss and global_step). Restore it from
# the full source before running.
parser = argparse.ArgumentParser()
402
## Required parameters
403
parser.add_argument("--train_corpus",
407
help="The input train corpus.")
408
parser.add_argument("--bert_model", default=None, type=str, required=True,
409
help="Bert pre-trained model selected in the list: bert-base-uncased, "
410
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
411
parser.add_argument("--output_dir",
415
help="The output directory where the model checkpoints will be written.")
418
parser.add_argument("--max_seq_length",
421
help="The maximum total input sequence length after WordPiece tokenization. \n"
422
"Sequences longer than this will be truncated, and sequences shorter \n"
423
"than this will be padded.")
424
parser.add_argument("--do_train",
426
help="Whether to run training.")
427
parser.add_argument("--train_batch_size",
430
help="Total batch size for training.")
431
parser.add_argument("--learning_rate",
434
help="The initial learning rate for Adam.")
435
parser.add_argument("--num_train_epochs",
438
help="Total number of training epochs to perform.")
439
parser.add_argument("--warmup_proportion",
442
help="Proportion of training to perform linear learning rate warmup for. "
443
"E.g., 0.1 = 10%% of training.")
444
parser.add_argument("--no_cuda",
446
help="Whether not to use CUDA when available")
447
parser.add_argument("--on_memory",
449
help="Whether to load train samples into memory or use disk")
450
parser.add_argument("--do_lower_case",
452
help="Whether to lower case the input text. True for uncased models, False for cased models.")
453
parser.add_argument("--local_rank",
456
help="local_rank for distributed training on gpus")
457
parser.add_argument('--seed',
460
help="random seed for initialization")
461
parser.add_argument('--gradient_accumulation_steps',
464
help="Number of updates steps to accumualte before performing a backward/update pass.")
465
parser.add_argument('--fp16',
467
help="Whether to use 16-bit float precision instead of 32-bit")
468
parser.add_argument('--loss_scale',
469
type = float, default = 0,
470
help = "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
471
"0 (default value): dynamic loss scaling.\n"
472
"Positive power of 2: static loss scaling value.\n")
474
args = parser.parse_args()
476
# Non-distributed run (or CPU forced): single process, possibly several GPUs.
if args.local_rank == -1 or args.no_cuda:
477
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
478
n_gpu = torch.cuda.device_count()
480
torch.cuda.set_device(args.local_rank)
481
device = torch.device("cuda", args.local_rank)
483
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
484
torch.distributed.init_process_group(backend='nccl')
485
logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
486
device, n_gpu, bool(args.local_rank != -1), args.fp16))
488
if args.gradient_accumulation_steps < 1:
489
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
490
args.gradient_accumulation_steps))
492
# From here on train_batch_size is the per-forward micro-batch; the optimizer
# steps once every gradient_accumulation_steps micro-batches (see the training loop).
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
494
random.seed(args.seed)
495
np.random.seed(args.seed)
496
torch.manual_seed(args.seed)
498
torch.cuda.manual_seed_all(args.seed)
500
if not args.do_train:
501
raise ValueError("Training is currently the only implemented execution option. Please set `do_train`.")
503
# Refuse to write into a non-empty output directory.
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
504
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
505
if not os.path.exists(args.output_dir):
506
os.makedirs(args.output_dir)
508
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
510
#train_examples = None
511
num_train_optimization_steps = None
513
print("Loading Train Dataset", args.train_corpus)
514
train_dataset = BERTDataset(args.train_corpus, tokenizer, seq_len=args.max_seq_length,
515
corpus_lines=None, on_memory=args.on_memory)
516
# Total optimizer updates over all epochs; divided by the world size when distributed.
num_train_optimization_steps = int(
517
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
518
if args.local_rank != -1:
519
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
522
model = BertForPreTraining.from_pretrained(args.bert_model)
526
if args.local_rank != -1:
528
from apex.parallel import DistributedDataParallel as DDP
530
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
533
model = torch.nn.DataParallel(model)
537
param_optimizer = list(model.named_parameters())
538
# Bias and LayerNorm parameters are excluded from weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
539
optimizer_grouped_parameters = [
540
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
541
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
546
from apex.optimizers import FP16_Optimizer
547
from apex.optimizers import FusedAdam
549
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
551
optimizer = FusedAdam(optimizer_grouped_parameters,
552
lr=args.learning_rate,
553
bias_correction=False,
555
# loss_scale == 0 selects dynamic loss scaling; any other value is used as a static scale.
if args.loss_scale == 0:
556
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
558
optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
559
warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
560
t_total=num_train_optimization_steps)
563
optimizer = BertAdam(optimizer_grouped_parameters,
564
lr=args.learning_rate,
565
warmup=args.warmup_proportion,
566
t_total=num_train_optimization_steps)
570
logger.info("***** Running training *****")
571
logger.info(" Num examples = %d", len(train_dataset))
572
logger.info(" Batch size = %d", args.train_batch_size)
573
logger.info(" Num steps = %d", num_train_optimization_steps)
575
if args.local_rank == -1:
576
train_sampler = RandomSampler(train_dataset)
578
#TODO: check if this works with current data generator from disk that relies on next(file)
579
# (it doesn't return item back by index)
580
train_sampler = DistributedSampler(train_dataset)
581
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
584
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
586
nb_tr_examples, nb_tr_steps = 0, 0
587
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
588
batch = tuple(t.to(device) for t in batch)
589
input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
590
loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
592
loss = loss.mean() # mean() to average on multi-gpu.
593
# Scale the loss so the accumulated gradient matches one full-batch step.
if args.gradient_accumulation_steps > 1:
594
loss = loss / args.gradient_accumulation_steps
596
# NOTE(review): optimizer.backward is the fp16-optimizer path; the non-fp16
# backward call is not visible in this span.
optimizer.backward(loss)
599
tr_loss += loss.item()
600
nb_tr_examples += input_ids.size(0)
602
# Optimizer update once every gradient_accumulation_steps micro-batches.
if (step + 1) % args.gradient_accumulation_steps == 0:
604
# modify learning rate with special warm up BERT uses
605
# if args.fp16 is False, BertAdam is used that handles this automatically
606
# NOTE(review): confirm the WarmupLinearSchedule.get_lr signature; the second
# positional argument passed here is warmup_proportion.
lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
607
for param_group in optimizer.param_groups:
608
param_group['lr'] = lr_this_step
610
optimizer.zero_grad()
613
# Save a trained model
614
logger.info("** ** * Saving fine - tuned model ** ** * ")
615
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
616
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
618
torch.save(model_to_save.state_dict(), output_model_file)
621
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
622
"""Truncates a sequence pair in place to the maximum length."""
624
# This is a simple heuristic which will always truncate the longer sequence
625
# one token at a time. This makes more sense than truncating an equal percent
626
# of tokens from each, since if one sequence is very short then each token
627
# that's truncated likely contains more information than a longer sequence.
629
total_length = len(tokens_a) + len(tokens_b)
630
if total_length <= max_length:
632
if len(tokens_a) > len(tokens_b):
def accuracy(out, labels):
    """Count predictions whose row-wise argmax matches the label.

    :param out: 2-D array-like of scores, one row per example.
    :param labels: 1-D array-like of integer class labels, one per row of `out`.
    :return: number of correct predictions (a count, not a ratio).
    """
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)
643
# Script entry guard.
# NOTE(review): the statement guarded by this `if` (presumably the call that
# starts training) is not present in this span, so as shown the file ends with
# an empty block; restore it from the full source.
if __name__ == "__main__":