colossalai

Форк
0
202 строки · 9.1 Кб
1
from typing import Dict, List, Optional
2

3
from coati.experience_buffer import NaiveExperienceBuffer
4
from coati.experience_maker import Experience, NaiveExperienceMaker
5
from coati.models.base import Actor, Critic, RewardModel, get_base_model
6
from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
7
from coati.models.utils import calc_action_log_probs
8
from torch.optim import Optimizer
9
from torch.utils.data import DataLoader, DistributedSampler
10
from tqdm import tqdm
11
from transformers import PreTrainedTokenizerBase
12

13
from colossalai.accelerator import get_accelerator
14

15
from .base import OnPolicyTrainer
16
from .callbacks import Callback
17
from .strategies import GeminiStrategy, Strategy
18
from .utils import CycledDataLoader, is_rank_0, to_device
19

20

21
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
    """Return a copy of ``generate_kwargs`` with HuggingFace generation hooks filled in.

    The actor is unwrapped via ``strategy.unwrap_model`` and reduced to its base
    HuggingFace model with ``get_base_model``. For each hook name that the caller did
    not already supply, the matching method of that model is added, but only if the
    model actually defines it.

    Args:
        strategy: strategy used to unwrap the (possibly wrapped) actor.
        generate_kwargs: caller-supplied generation kwargs; not modified.
        actor: the actor model whose base HuggingFace model provides the hooks.

    Returns:
        A new dict containing every entry of ``generate_kwargs`` plus any filled-in hooks.
    """
    base_model = get_base_model(strategy.unwrap_model(actor))
    merged = dict(generate_kwargs)

    # (kwarg name expected by generation, attribute on the HuggingFace model)
    hook_sources = (
        ("prepare_inputs_fn", "prepare_inputs_for_generation"),
        ("update_model_kwargs_fn", "_update_model_kwargs_for_generation"),
    )
    for kwarg_name, attr_name in hook_sources:
        # An explicit caller-provided value always wins.
        if kwarg_name in generate_kwargs:
            continue
        if hasattr(base_model, attr_name):
            merged[kwarg_name] = getattr(base_model, attr_name)
    return merged
33

34

35
class PPOTrainer(OnPolicyTrainer):
    """
    Trainer for the PPO algorithm.

    Args:
        strategy (Strategy): the strategy to use for training
        actor (Actor): the actor model in ppo algorithm
        critic (Critic): the critic model in ppo algorithm
        reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
        initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
        actor_optim (Optimizer): the optimizer to use for actor model
        critic_optim (Optimizer): the optimizer to use for critic model
        tokenizer (PreTrainedTokenizerBase): the tokenizer handed to the experience maker
        kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss (passed to the experience maker)
        ptx_coef (float, defaults to 0.9): the coefficient of ptx loss; the policy loss is scaled by (1 - ptx_coef)
        train_batch_size (int, defaults to 8): the batch size to use for training
        buffer_limit (int, defaults to 0): the max_size limitation of buffer
        buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
        eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
        vf_coef (float, defaults to 1.0): the coefficient of value loss
        value_clip (float, defaults to 0.4): the clip coefficient of value loss
        sample_buffer (bool, defaults to False): whether to sample from buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        offload_inference_models (bool, defaults to True): whether to offload the initial and reward models
            to cpu during the training phase; must be False when using GeminiStrategy
        callbacks (List[Callback], optional): the callbacks to call during training process;
            defaults to a new empty list for each trainer
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(
        self,
        strategy: Strategy,
        actor: Actor,
        critic: Critic,
        reward_model: RewardModel,
        initial_model: Actor,
        actor_optim: Optimizer,
        critic_optim: Optimizer,
        tokenizer: PreTrainedTokenizerBase,
        kl_coef: float = 0.1,
        ptx_coef: float = 0.9,
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        buffer_cpu_offload: bool = True,
        eps_clip: float = 0.2,
        vf_coef: float = 1.0,
        value_clip: float = 0.4,
        sample_buffer: bool = False,
        dataloader_pin_memory: bool = True,
        offload_inference_models: bool = True,
        callbacks: Optional[List[Callback]] = None,
        **generate_kwargs,
    ) -> None:
        if isinstance(strategy, GeminiStrategy):
            # The manual model.to(...) offloading done in _make_experience/_learn is
            # incompatible with Gemini, per the assertion message below.
            assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"

        # Build a fresh list per trainer. A literal `[]` default would be one list
        # object shared by every trainer created without explicit callbacks.
        if callbacks is None:
            callbacks = []

        data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
        super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)

        self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)

        self.actor = actor
        self.critic = critic
        self.tokenizer = tokenizer

        self.actor_loss_fn = PolicyLoss(eps_clip)
        self.critic_loss_fn = ValueLoss(value_clip)
        self.vf_coef = vf_coef
        self.ptx_loss_fn = GPTLMLoss()
        self.ptx_coef = ptx_coef
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim

        self.offload_inference_models = offload_inference_models
        self.device = get_accelerator().get_current_device()

    def _before_fit(
        self,
        prompt_dataloader: DataLoader,
        pretrain_dataloader: DataLoader,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Prepare dataloaders and (on rank 0 only) logging before training starts.

        Args:
            prompt_dataloader (DataLoader): the dataloader to use for prompt data
            pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
            log_dir (str, optional): base directory for TensorBoard logs; a "ppo/<timestamp>"
                subdirectory is created under it. No writer is created when None.
            use_wandb (bool, defaults to False): whether to initialise wandb with TensorBoard
                syncing; requires ``log_dir``.
        """
        # Wrap both loaders so .next() can be called indefinitely across epochs.
        self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)

        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            # wandb is initialised before the SummaryWriter below is created, so that
            # sync_tensorboard can pick up the TensorBoard events.
            wandb.init(project="Coati-ppo", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "ppo")
            # NOTE(review): ':' in this timestamp is not a valid filename character on Windows.
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

    def _make_experience(self, collect_step: int) -> Experience:
        """Draw one prompt batch and turn it into an Experience via the experience maker.

        ``collect_step`` is not used here. The prompt batch must be a dict. Its entries
        are passed as keyword arguments together with ``self.generate_kwargs``, so the
        two must not share keys.
        """
        prompts = self.prompt_dataloader.next()
        if self.offload_inference_models:
            # Bring the inference-only models back onto the device; _learn moves them to cpu.
            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
            self.experience_maker.initial_model.to(self.device)
            self.experience_maker.reward_model.to(self.device)
        assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
        return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)

    def _training_step(self, experience: Experience):
        """Run one optimisation step for the actor and then one for the critic.

        The actor gradient combines two terms, each backpropagated before a single
        optimizer step:
        - the clipped policy loss, scaled by ``1 - ptx_coef``;
        - the pretraining LM (ptx) loss, scaled by ``ptx_coef``. This term is skipped
          when ``ptx_coef`` is 0.
        The critic is then updated with the clipped value loss, scaled by ``vf_coef``.
        """
        self.actor.train()
        self.critic.train()

        # policy loss: recompute log-probs of the generated actions under the current actor
        num_actions = experience.action_log_probs.size(1)
        actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
        action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
        actor_loss = self.actor_loss_fn(
            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
        )
        actor_loss = (1 - self.ptx_coef) * actor_loss
        self.strategy.backward(actor_loss, self.actor, self.actor_optim)

        # ptx loss: language-modelling loss on a pretraining batch. Its gradients add to
        # the policy-loss gradients because no optimizer step has happened yet.
        if self.ptx_coef != 0:
            batch = self.pretrain_dataloader.next()
            batch = to_device(batch, self.device)
            ptx_logits = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
            ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_logits, batch["labels"])
            self.strategy.backward(ptx_loss, self.actor, self.actor_optim)

        self.strategy.optimizer_step(self.actor_optim)
        self.actor_optim.zero_grad()

        # value loss
        values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
        critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
        critic_loss = critic_loss * self.vf_coef
        self.strategy.backward(critic_loss, self.critic, self.critic_optim)
        self.strategy.optimizer_step(self.critic_optim)
        self.critic_optim.zero_grad()

    def _learn(self, update_step: int):
        """Train on collected experience for one update step.

        When ``sample_buffer`` is set, a single batch is sampled from the buffer.
        Otherwise one full pass is made over ``self.dataloader``. That attribute is
        presumably built by the base trainer from the buffer (verify in OnPolicyTrainer).
        """
        if self.offload_inference_models:
            # Inference-only models are not needed while training; free device memory.
            self.experience_maker.initial_model.to("cpu")
            self.experience_maker.reward_model.to("cpu")

        # buffer may be empty at first, we should rebuild at each training
        if self.sample_buffer:
            experience = self.data_buffer.sample()
            self._on_learn_batch_start()
            experience.to_device(self.device)
            self._training_step(experience)
            self._on_learn_batch_end(experience)
        else:
            # Reseed the distributed sampler so each update step gets a different shuffle.
            if isinstance(self.dataloader.sampler, DistributedSampler):
                self.dataloader.sampler.set_epoch(update_step)
            pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                self._training_step(experience)
                self._on_learn_batch_end(experience)
203

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.