# detached_trainer_base.py
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List

import ray
import torch
from coati.experience_buffer.utils import BufferItem
from coati.experience_maker import Experience
from torch.utils.data import DataLoader
from tqdm import tqdm

from .callbacks import TrainerCallback
from .detached_replay_buffer import DetachedReplayBuffer
from .utils import is_rank_0


class DetachedTrainer(ABC):
    """
19
        Base class for detached rlhf trainers.
20
        'detach' means that the experience maker is detached compared to a normal Trainer.
21
        Please set name attribute during init:
22
            >>> trainer = DetachedTrainer.options(..., name = "xxx", ...).remote()
23
            So an ExperienceMakerHolder can reach the detached_replay_buffer by Actor's name.
24
    Args:
25
        detached_strategy (DetachedStrategy): the strategy to use for training
26
        detached_replay_buffer_ref (ObjectRef[DetachedReplayBuffer]): the replay buffer to use for training
27
        data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
28
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
29
        generate_kwargs (dict, optional): the kwargs to use while model generating
30

31
    """
32

    def __init__(
        self,
        experience_maker_holder_name_list: List[str],
        train_batch_size: int = 8,
        buffer_limit: int = 0,
        dataloader_pin_memory: bool = True,
        callbacks: List[TrainerCallback] = [],
        debug: bool = False,
    ) -> None:
        super().__init__()
        self.detached_replay_buffer = DetachedReplayBuffer(train_batch_size, limit=buffer_limit)
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks
        self.target_holder_name_list = experience_maker_holder_name_list
        self.target_holder_list = []
        self._is_target_holder_initialized = False
        self._debug = debug

    def update_target_holder_list(self):
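        """Lazily resolve the ExperienceMakerHolder actors named in
        ``target_holder_name_list`` into handles via ``ray.get_actor``."""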
        # target_holder_list may legitimately be empty, so track initialization with a
        # boolean flag instead of checking its length
        if not self._is_target_holder_initialized:
            for name in self.target_holder_name_list:
                self.target_holder_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))
            self._is_target_holder_initialized = True

    @abstractmethod
    def _update_remote_makers(self, fully_update: bool = False, **kwargs):
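        """Push the trainer's current model weights to the registered ExperienceMakerHolder actors.

        Args:
            fully_update (bool): if True, synchronize the full model state (as done by
                ``sync_models_to_remote_makers``) rather than a partial update.
        """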
        pass

    def sync_models_to_remote_makers(self, **kwargs):
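        """Send a full copy of the models to all remote experience makers; thin wrapper
        around ``_update_remote_makers(fully_update=True)``."""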
        self._update_remote_makers(fully_update=True, **kwargs)

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
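        """Run a single optimization step on one batch of experience and return the
        metrics that will be reported to callbacks and the progress bar."""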
        pass

    def _learn(self, update_steps: int, train_epochs: int) -> None:
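        """Train for ``train_epochs`` epochs, each covering ``update_steps`` batches.

        The first (warmup) epoch pulls batches from the detached replay buffer and
        caches them in ``data``; subsequent epochs replay the cached batches through
        a DataLoader.
        """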
        data = []
        # warmup
        pbar = tqdm(range(update_steps), desc=f"Train epoch [1/{train_epochs}]", disable=not is_rank_0())
        self._on_epoch_start(0)
        self._learn_epoch(pbar, data)
        self._on_epoch_end(0)
        # item is already a batch
        dataloader = DataLoader(
            data, batch_size=1, shuffle=True, pin_memory=self.dataloader_pin_memory, collate_fn=lambda x: x[0]
        )
        for epoch in range(1, train_epochs):
            pbar = tqdm(dataloader, desc=f"Train epoch [{epoch + 1}/{train_epochs}]", disable=not is_rank_0())
            self._on_epoch_start(epoch)
            self._learn_epoch(pbar, data)
            self._on_epoch_end(epoch)

    def _learn_epoch(self, pbar: tqdm, data: List[Experience]) -> None:
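        """Run one epoch over ``pbar``. During warmup (``data`` is empty), batches are
        sampled from the replay buffer and appended to ``data`` for reuse in later epochs."""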
        is_warmup = len(data) == 0
        for x in pbar:
            if self._debug:
                print("[trainer] training step")
            # sample a batch and then train to avoid waiting
            experience = x if not is_warmup else self._buffer_sample()
            experience.to_device(torch.cuda.current_device())
            self._on_batch_start()
            metrics = self.training_step(experience)
            self._on_batch_end(metrics, experience)

            if self._debug:
                print("[trainer] step over")
            experience.to_device("cpu")
            if is_warmup:
                data.append(experience)
            pbar.set_postfix(metrics)

    def fit(self, total_steps: int, update_steps: int, train_epochs: int = 1) -> None:
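        """Main training loop: runs ``total_steps // update_steps`` episodes, each of which
        trains on ``update_steps`` batches and then pushes the updated models to the
        remote experience makers."""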
        self._on_fit_start()
        for i in tqdm(range(total_steps // update_steps), desc="Trainer", disable=not is_rank_0()):
            self._on_episode_start(i)
            self._learn(update_steps, train_epochs)
            self._on_update_start()
            self._update_remote_makers()
            self._on_update_end()
            self._on_episode_end(i)
        self._on_fit_end()

    @ray.method(concurrency_group="buffer_length")
    def buffer_get_length(self):
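        """Report the current size of the detached replay buffer.

        The dedicated ``buffer_length`` concurrency group lets Ray serve this call
        concurrently with other actor methods, assuming the concrete trainer actor
        declares matching ``concurrency_groups`` in its ``@ray.remote`` options.
        """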
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer]                telling length")
        return self.detached_replay_buffer.get_length()

    @ray.method(concurrency_group="buffer_append")
    def buffer_append(self, experience: Experience):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer]               receiving exp.")
        self.detached_replay_buffer.append(experience)

    @ray.method(concurrency_group="buffer_append")
    def buffer_extend(self, items: List[BufferItem]):
        # called by ExperienceMakerHolder
        if self._debug:
            print("[trainer]               receiving exp.")
        self.detached_replay_buffer.extend(items)

    @ray.method(concurrency_group="buffer_sample")
    def _buffer_sample(self):
        return self.detached_replay_buffer.sample()

    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_start(epoch)

    def _on_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_epoch_end(epoch)

    def _on_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_start()

    def _on_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_batch_end(metrics, experience)

    def _on_update_start(self) -> None:
        for callback in self.callbacks:
            callback.on_update_start()

    def _on_update_end(self) -> None:
        for callback in self.callbacks:
            callback.on_update_end()
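

# Usage sketch (illustrative; not part of the original file). A concrete subclass
# (a hypothetical ``MyDetachedTrainer`` here) is typically launched as a *named* Ray
# actor so that experience makers can reach it by name:
#
#     trainer = MyDetachedTrainer.options(
#         name="trainer1", namespace=os.environ["RAY_NAMESPACE"]
#     ).remote(experience_maker_holder_name_list=["maker1"])
#     trainer.fit.remote(total_steps=100, update_steps=10, train_epochs=1)
#
# An ExperienceMakerHolder can then feed the detached replay buffer through the actor name:
#
#     handle = ray.get_actor("trainer1", namespace=os.environ["RAY_NAMESPACE"])
#     handle.buffer_extend.remote(items)  # items: List[BufferItem]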