colossalai

Форк
0
130 строк · 4.8 Кб
1
from typing import Optional
2

3
import torch
4
import torch.distributed as dist
5
import tqdm
6
from torch.optim import Optimizer
7
from torch.optim.lr_scheduler import _LRScheduler
8
from torch.utils.data import DataLoader
9

10
from colossalai.logging import DistributedLogger
11

12
from .base import SLTrainer
13
from .strategies import GeminiStrategy, Strategy
14
from .utils import is_rank_0, to_device
15

16

17
class SFTTrainer(SLTrainer):
18
    """
19
        Trainer to use while training reward model.
20

21
    Args:
22
        model (torch.nn.Module): the model to train
23
        strategy (Strategy): the strategy to use for training
24
        optim(Optimizer): the optimizer to use for training
25
        lr_scheduler(_LRScheduler): the lr scheduler to use for training
26
        max_epochs (int, defaults to 2): the number of epochs to train
27
        accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients
28
    """
29

30
    def __init__(
31
        self,
32
        model,
33
        strategy: Strategy,
34
        optim: Optimizer,
35
        lr_scheduler: _LRScheduler,
36
        max_epochs: int = 2,
37
        accumulation_steps: int = 8,
38
    ) -> None:
39
        if accumulation_steps > 1:
40
            assert not isinstance(
41
                strategy, GeminiStrategy
42
            ), "Accumulation steps are not supported in stage 3 of ColossalAI"
43

44
        super().__init__(strategy, max_epochs, model, optim)
45

46
        self.accumulation_steps = accumulation_steps
47
        self.scheduler = lr_scheduler
48

49
        self.num_train_step = 0
50
        self.num_eval_step = 0
51

52
    def _train(self, epoch: int):
53
        self.model.train()
54
        step_bar = tqdm.trange(
55
            len(self.train_dataloader) // self.accumulation_steps,
56
            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
57
            disable=not is_rank_0(),
58
        )
59
        for i, batch in enumerate(self.train_dataloader):
60
            batch = to_device(batch, torch.cuda.current_device())
61
            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
62
            loss = outputs.loss / self.accumulation_steps
63
            self.total_loss += loss.item()
64
            self.strategy.backward(loss, self.model, self.optimizer)
65
            # gradient accumulation
66
            if (i + 1) % self.accumulation_steps == 0:
67
                self.strategy.optimizer_step(self.optimizer)
68
                self.optimizer.zero_grad()
69
                self.scheduler.step()
70
                if self.writer:
71
                    self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
72
                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
73
                    self.num_train_step += 1
74
                self.total_loss = 0
75
                step_bar.update()
76
        step_bar.close()
77

78
    def _eval(self, epoch: int):
79
        if self.eval_dataloader is not None:
80
            self.model.eval()
81
            with torch.no_grad():
82
                loss_sum, num_seen = 0, 0
83
                for batch in self.eval_dataloader:
84
                    batch = to_device(batch, torch.cuda.current_device())
85
                    outputs = self.model(
86
                        batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
87
                    )
88
                    loss_sum += outputs.loss.item()
89
                    num_seen += batch["input_ids"].size(0)
90
                loss_mean = loss_sum / num_seen
91
                if dist.get_rank() == 0:
92
                    self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
93
                if self.writer:
94
                    self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
95
                    self.num_eval_step += 1
96

97
    def _before_fit(
98
        self,
99
        train_dataloader: DataLoader,
100
        eval_dataloader: Optional[DataLoader] = None,
101
        logger: Optional[DistributedLogger] = None,
102
        log_dir: Optional[str] = None,
103
        use_wandb: bool = False,
104
    ):
105
        """
106
        Args:
107
            train_dataloader: the dataloader to use for training
108
            eval_dataloader: the dataloader to use for evaluation
109
        """
110
        self.train_dataloader = train_dataloader
111
        self.eval_dataloader = eval_dataloader
112

113
        self.logger = logger
114
        self.writer = None
115
        if use_wandb and is_rank_0():
116
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
117
            import wandb
118

119
            wandb.init(project="Coati-sft", sync_tensorboard=True)
120
        if log_dir is not None and is_rank_0():
121
            import os
122
            import time
123

124
            from torch.utils.tensorboard import SummaryWriter
125

126
            log_dir = os.path.join(log_dir, "sft")
127
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
128
            self.writer = SummaryWriter(log_dir=log_dir)
129

130
        self.total_loss = 0
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.