workflow.py
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py

from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments

from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.model import load_model_and_tokenizer
from llmtuner.train.sft.metric import ComputeMetrics
from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer
from llmtuner.train.utils import create_modelcard_and_push

if TYPE_CHECKING:
    from transformers import TrainerCallback
    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
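    # These imports run only under static type checkers (TYPE_CHECKING is
    # False at runtime), which avoids import-time cost and circular imports;
    # the annotations below therefore use forward-reference strings.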


def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[List["TrainerCallback"]] = None
):
    dataset = get_dataset(model_args, data_args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
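    # With stage="sft", preprocessing renders each example with the chat
    # template and masks the prompt tokens in the labels with IGNORE_INDEX,
    # so the loss is computed on response tokens only.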

    if training_args.predict_with_generate:
        tokenizer.padding_side = "left"  # use left-padding in generation
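        # Decoder-only models append new tokens on the right, so batched
        # generation needs prompts aligned to the right edge (left-padding);
        # right-padding would leave pad tokens between prompt and generation.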

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )
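    # IGNORE_INDEX is -100, the value PyTorch's cross-entropy loss ignores by
    # default, so padded label positions contribute nothing to the loss.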

    # Override the decoding parameters of Seq2SeqTrainer
    training_args_dict = training_args.to_dict()
    training_args_dict.update(dict(
        generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
        generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
    ))
    training_args = Seq2SeqTrainingArguments(**training_args_dict)
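    # Rebuilding the arguments from a dict (rather than mutating in place)
    # re-runs the dataclass __post_init__ validation with the new generation
    # defaults: cap generated length at cutoff_len, prefer eval_num_beams.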

    # Initialize our Trainer
    trainer = CustomSeq2SeqTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **split_dataset(dataset, data_args, training_args)
    )
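    # split_dataset expands to train_dataset/eval_dataset keyword arguments,
    # carving out a validation split when data_args.val_size is set.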

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
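            # plot_loss reads the log history saved under output_dir and
            # writes a PNG per requested curve; the is_world_process_zero
            # guard ensures only the main process writes files.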

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        if training_args.predict_with_generate:  # eval_loss will be wrong if predict_with_generate is enabled
            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
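
# Minimal usage sketch (an assumption, not part of the original file): in
# LLaMA-Factory this function is normally dispatched from the project's
# training entry point after its hparams parser has built the five argument
# objects, roughly:
#
#     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args()
#     run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks=[])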