lmops

Форк
0
/
train_biencoder.py 
75 строк · 2.2 Кб
1
import logging
2

3
from transformers.utils.logging import enable_explicit_format, set_verbosity_info, set_verbosity_warning
4
from transformers.trainer_callback import PrinterCallback
5
from transformers import (
6
    AutoTokenizer,
7
    HfArgumentParser,
8
    Trainer,
9
    set_seed,
10
    PreTrainedTokenizerFast
11
)
12

13
from logger_config import logger, LoggerCallback
14
from config import Arguments
15
from trainers import BiencoderTrainer
16
from loaders import RetrievalDataLoader
17
from collators import BiencoderCollator
18
from models import BiencoderModel
19

20

21
def _common_setup(args: Arguments):
22
    set_verbosity_info()
23
    if args.process_index > 0:
24
        logger.setLevel(logging.WARNING)
25
        set_verbosity_warning()
26
    enable_explicit_format()
27
    set_seed(args.seed)
28

29

30
def main():
31
    parser = HfArgumentParser((Arguments,))
32
    args: Arguments = parser.parse_args_into_dataclasses()[0]
33
    _common_setup(args)
34
    logger.info('Args={}'.format(str(args)))
35

36
    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
37
    model: BiencoderModel = BiencoderModel.build(args=args)
38
    logger.info(model)
39
    logger.info('Vocab size: {}'.format(len(tokenizer)))
40

41
    data_collator = BiencoderCollator(
42
        args=args,
43
        tokenizer=tokenizer,
44
        pad_to_multiple_of=8 if args.fp16 else None)
45

46
    retrieval_data_loader = RetrievalDataLoader(args=args, tokenizer=tokenizer)
47
    train_dataset = retrieval_data_loader.train_dataset
48

49
    trainer: Trainer = BiencoderTrainer(
50
        model=model,
51
        args=args,
52
        train_dataset=train_dataset if args.do_train else None,
53
        data_collator=data_collator,
54
        tokenizer=tokenizer,
55
    )
56
    trainer.remove_callback(PrinterCallback)
57
    trainer.add_callback(LoggerCallback)
58
    retrieval_data_loader.set_trainer(trainer)
59
    model.trainer = trainer
60

61
    if args.do_train:
62
        train_result = trainer.train()
63
        trainer.save_model()
64

65
        metrics = train_result.metrics
66
        metrics["train_samples"] = len(train_dataset)
67

68
        trainer.log_metrics("train", metrics)
69
        trainer.save_metrics("train", metrics)
70

71
    return
72

73

74
if __name__ == "__main__":
75
    main()
76

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.