colossalai
130 строк · 4.8 Кб
1from typing import Optional
2
3import torch
4import torch.distributed as dist
5import tqdm
6from torch.optim import Optimizer
7from torch.optim.lr_scheduler import _LRScheduler
8from torch.utils.data import DataLoader
9
10from colossalai.logging import DistributedLogger
11
12from .base import SLTrainer
13from .strategies import GeminiStrategy, Strategy
14from .utils import is_rank_0, to_device
15
16
class SFTTrainer(SLTrainer):
    """
    Trainer for supervised fine-tuning (SFT).

    Each batch is expected to be a mapping with ``"input_ids"``, ``"attention_mask"``
    and ``"labels"`` entries. The model is called with those and must return an
    object exposing a scalar ``.loss`` tensor.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        lr_scheduler (_LRScheduler): the lr scheduler to use for training; it is
            stepped once per optimizer step (not once per batch)
        max_epochs (int, defaults to 2): the number of epochs to train
        accumulation_steps (int, defaults to 8): the number of batches whose gradients
            are accumulated before each optimizer step. Must be >= 1, and must be 1
            when ``strategy`` is a ``GeminiStrategy``.

    Raises:
        ValueError: if ``accumulation_steps`` is smaller than 1.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        strategy: Strategy,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        max_epochs: int = 2,
        accumulation_steps: int = 8,
    ) -> None:
        # A value < 1 would otherwise only fail later inside ``_train``
        # (``// accumulation_steps`` / ``% accumulation_steps``).
        if accumulation_steps < 1:
            raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
        if accumulation_steps > 1:
            assert not isinstance(
                strategy, GeminiStrategy
            ), "Accumulation steps are not supported in stage 3 of ColossalAI"

        super().__init__(strategy, max_epochs, model, optim)

        self.accumulation_steps = accumulation_steps
        self.scheduler = lr_scheduler

        # Global step indices used as the x-axis of the TensorBoard scalars.
        self.num_train_step = 0
        self.num_eval_step = 0

    def _train(self, epoch: int):
        """Run one training epoch with gradient accumulation.

        Every batch loss is divided by ``accumulation_steps`` before ``backward()``.
        An optimizer step, a ``zero_grad()`` and a scheduler step are performed once
        every ``accumulation_steps`` batches.
        """
        self.model.train()
        num_batches = len(self.train_dataloader)
        step_bar = tqdm.trange(
            num_batches // self.accumulation_steps,
            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
            disable=not is_rank_0(),
        )
        for i, batch in enumerate(self.train_dataloader):
            batch = to_device(batch, torch.cuda.current_device())
            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            loss = outputs.loss / self.accumulation_steps
            self.total_loss += loss.item()
            self.strategy.backward(loss, self.model, self.optimizer)
            # gradient accumulation: only step once a full window of batches was seen
            if (i + 1) % self.accumulation_steps == 0:
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                self.scheduler.step()
                if self.writer:
                    self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
                self.num_train_step += 1
                self.total_loss = 0
                step_bar.update()
        step_bar.close()

        # If the number of batches is not a multiple of accumulation_steps, the
        # trailing batches have already run backward() but never reached an
        # optimizer step. Discard them so their gradients and loss do not leak into
        # the first step of the next epoch. The progress bar above also only
        # counts full accumulation windows.
        if num_batches % self.accumulation_steps != 0:
            self.optimizer.zero_grad()
            self.total_loss = 0

    def _eval(self, epoch: int):
        """Evaluate on ``eval_dataloader`` (if set) and report the mean loss.

        The reported value is the batch-size-weighted mean of the per-batch losses.
        It is logged through ``self.logger`` on rank 0 when a logger was provided,
        and written to TensorBoard when a writer exists. Nothing is reported for an
        empty eval loader.
        """
        if self.eval_dataloader is None:
            return
        self.model.eval()
        loss_sum, num_seen = 0.0, 0
        with torch.no_grad():
            for batch in self.eval_dataloader:
                batch = to_device(batch, torch.cuda.current_device())
                outputs = self.model(
                    batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
                )
                batch_size = batch["input_ids"].size(0)
                # Assumes outputs.loss is averaged over the batch (as the training
                # path's division by accumulation_steps implies). Weighting by
                # batch_size makes loss_sum / num_seen a per-sample mean instead of
                # a value scaled down by the batch size.
                loss_sum += outputs.loss.item() * batch_size
                num_seen += batch_size
        if num_seen == 0:
            # Empty eval loader: nothing to report (avoids ZeroDivisionError).
            return
        loss_mean = loss_sum / num_seen
        if self.logger is not None and is_rank_0():
            # epoch + 1 matches the 1-based epoch shown by the training progress bar.
            self.logger.info(f"Eval Epoch {epoch + 1}/{self.max_epochs} loss {loss_mean}")
        if self.writer:
            self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
        self.num_eval_step += 1

    def _before_fit(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        logger: Optional[DistributedLogger] = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            train_dataloader: the dataloader to use for training
            eval_dataloader: the dataloader to use for evaluation; evaluation is
                skipped when None
            logger: logger used for evaluation messages on rank 0; nothing is
                logged when None
            log_dir: root directory for TensorBoard logs. On rank 0 a writer is
                created under ``<log_dir>/sft/<timestamp>``. No writer is created
                when None.
            use_wandb: if True, initialise wandb on rank 0 with TensorBoard syncing
                (requires ``log_dir``)
        """
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        self.logger = logger
        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            wandb.init(project="Coati-sft", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "sft")
            # NOTE(review): the timestamp contains ':' characters, which are not
            # valid in Windows paths.
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

        # Running (accumulation-scaled) loss of the current optimizer step.
        self.total_loss = 0
131