# lmops
# 75 lines · 2.2 KB
1import logging
2
3from transformers.utils.logging import enable_explicit_format, set_verbosity_info, set_verbosity_warning
4from transformers.trainer_callback import PrinterCallback
5from transformers import (
6AutoTokenizer,
7HfArgumentParser,
8Trainer,
9set_seed,
10PreTrainedTokenizerFast
11)
12
13from logger_config import logger, LoggerCallback
14from config import Arguments
15from trainers import BiencoderTrainer
16from loaders import RetrievalDataLoader
17from collators import BiencoderCollator
18from models import BiencoderModel
19
20
def _common_setup(args: Arguments):
    """Configure log verbosity, log formatting and the random seed.

    The transformers verbosity is first set to info. When
    ``args.process_index`` is greater than zero, the project ``logger`` is
    lowered to ``logging.WARNING`` and the transformers verbosity is lowered
    to warning as well. Explicit log formatting is then enabled and
    ``set_seed`` is called with ``args.seed``.

    Args:
        args: Parsed run arguments; ``process_index`` and ``seed`` are read.
    """
    set_verbosity_info()

    # NOTE(review): process_index > 0 presumably identifies non-primary
    # workers in a multi-process run -- confirm against ``Arguments``.
    is_secondary_process = args.process_index > 0
    if is_secondary_process:
        logger.setLevel(logging.WARNING)
        set_verbosity_warning()

    enable_explicit_format()
    set_seed(args.seed)
28
29
def main():
    """Parse the run arguments, assemble the bi-encoder trainer and train.

    Steps, in order:
      1. Parse ``Arguments`` with ``HfArgumentParser`` and apply
         ``_common_setup``.
      2. Load the tokenizer from ``args.model_name_or_path`` and build the
         ``BiencoderModel``.
      3. Create the collator, the retrieval data loader and the
         ``BiencoderTrainer``; replace ``PrinterCallback`` with
         ``LoggerCallback``.
      4. When ``args.do_train`` is set: train, save the model, then log and
         save the training metrics (with ``train_samples`` added).
    """
    arg_parser = HfArgumentParser((Arguments,))
    parsed = arg_parser.parse_args_into_dataclasses()
    args: Arguments = parsed[0]
    _common_setup(args)
    args_repr = str(args)
    logger.info('Args={}'.format(args_repr))

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
        args.model_name_or_path
    )
    model: BiencoderModel = BiencoderModel.build(args=args)
    logger.info(model)
    vocab_size = len(tokenizer)
    logger.info('Vocab size: {}'.format(vocab_size))

    # Padding to a multiple of 8 is requested only when fp16 is enabled.
    if args.fp16:
        pad_multiple = 8
    else:
        pad_multiple = None
    data_collator = BiencoderCollator(
        args=args,
        tokenizer=tokenizer,
        pad_to_multiple_of=pad_multiple,
    )

    retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer)
    train_dataset = retrieval_data_loader.train_dataset

    # The trainer only receives the dataset when training was requested.
    if args.do_train:
        dataset_for_trainer = train_dataset
    else:
        dataset_for_trainer = None
    trainer: Trainer = BiencoderTrainer(
        model=model,
        args=args,
        train_dataset=dataset_for_trainer,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    # Swap the default printing callback for the project's logging callback.
    trainer.remove_callback(PrinterCallback)
    trainer.add_callback(LoggerCallback)

    # Both the data loader and the model are handed a reference to the trainer.
    retrieval_data_loader.set_trainer(trainer)
    model.trainer = trainer

    if not args.do_train:
        return

    train_result = trainer.train()
    trainer.save_model()

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
72
73
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
76