colossalai / experience_maker_holder.py

import os
import time
import tracemalloc
from threading import Lock
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

import ray
import torch
from coati.experience_buffer.utils import split_experience_batch
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models.base import Actor, Critic, RewardModel
from coati.trainer.strategies import Strategy
from torch import Tensor
from tqdm import tqdm

from .callbacks import ExperienceMakerPerformanceEvaluator, MakerCallback
from .lora_constructor import LoRAConstructor
from .utils import get_model_numel, get_rank, is_rank_0, set_dist_env, state_dict_to


@ray.remote(concurrency_groups={"experience_io": 1, "model_io": 1, "compute": 1})
class ExperienceMakerHolder:
    """
24
    Args:
25
        detached_trainer_name_list: str list to get ray actor handles
26
        strategy:
27
        kl_coef: the coefficient of kl divergence loss
28
        sync_models_from_trainers: whether to sync models from trainers. If True, you must call sync_models_to_remote_makers() in trainers to sync models.
29
    """
30

31
    def __init__(
        self,
        detached_trainer_name_list: List[str],
        strategy_fn: Callable[[], Strategy],
        # a function that returns (actor, critic, reward_model, initial_model)
        model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
        env_info: Dict[str, str] = None,
        sync_models_from_trainers: bool = False,
        buffer_cpu_offload: bool = True,
        kl_coef: float = 0.1,
        callbacks: List[MakerCallback] = [],
        eval_performance: bool = False,
        debug: bool = False,
        update_lora_weights: bool = False,
        **generate_kwargs,
    ):
        # set environment variables
        if env_info:
            set_dist_env(env_info=env_info)
        self.target_trainer_list = []
        assert len(detached_trainer_name_list) > 0
        self._detached_trainer_name_list = detached_trainer_name_list
        self.strategy = strategy_fn()
        self.buffer_cpu_offload = buffer_cpu_offload
        self.kl_coef = kl_coef
        # init models
        with self.strategy.model_init_context():
            actor, critic, reward_model, initial_model = model_fn()
        self.generate_kwargs = _set_default_generate_kwargs(generate_kwargs, actor)
        if eval_performance:
            actor_numel = get_model_numel(actor)
            critic_numel = get_model_numel(critic)
            initial_model_numel = get_model_numel(initial_model)
            reward_model_numel = get_model_numel(reward_model)
            evaluator = ExperienceMakerPerformanceEvaluator(
                actor_numel, critic_numel, initial_model_numel, reward_model_numel
            )
            callbacks = callbacks + [evaluator]

        actor, critic, reward_model, initial_model = self.strategy.prepare(actor, critic, reward_model, initial_model)
        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, self.kl_coef)
        self.callbacks = callbacks

        self._model_visit_lock = Lock()

        self._is_fully_initialized = not sync_models_from_trainers

        self._debug = debug
        self._update_lora_weights = update_lora_weights
        if self._update_lora_weights:
            self.actor_lora_constructor = LoRAConstructor()
            self.critic_lora_constructor = LoRAConstructor()

        self.target_auto_balance = False

        self._target_idx = 0

        if self._debug:
            print(f"[maker{get_rank()}] will send items to {self._detached_trainer_name_list}")
            if not self._is_fully_initialized:
                print(f"[maker{get_rank()}] Waiting for INIT")

    def _get_ready(self):
        while not self._fully_initialized():
            time.sleep(1.0)

    def _fully_initialized(self):
        return self._is_fully_initialized

    def _init_target_trainer_list(self):
        if len(self.target_trainer_list) > 0:
            return
        for name in self._detached_trainer_name_list:
            self.target_trainer_list.append(ray.get_actor(name, namespace=os.environ["RAY_NAMESPACE"]))

    # copied from ../trainer/base.py
    @ray.method(concurrency_group="compute")
108
    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
109
        if isinstance(inputs, Tensor):
110
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
111
        elif isinstance(inputs, dict):
112
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
113
        else:
114
            raise ValueError(f'Unsupported input type "{type(inputs)}"')
115

116
    @ray.method(concurrency_group="experience_io")
117
    def _send_items(self, experience: Experience) -> None:
118
        self._init_target_trainer_list()
119
        items = split_experience_batch(experience)
120
        items_per_trainer = [[] for _ in range(len(self.target_trainer_list))]
121
        for item in items:
122
            items_per_trainer[self._target_idx].append(item)
123
            self._target_idx = (self._target_idx + 1) % len(self.target_trainer_list)
124
        for i, target_trainer in enumerate(self.target_trainer_list):
125
            if len(items_per_trainer[i]) > 0:
126
                target_trainer.buffer_extend.remote(items_per_trainer[i])
127

128
    def _inference_step(self, batch) -> None:
129
        self._on_batch_start()
130
        with self._model_visit_lock:
131
            self._on_make_experience_start()
132
            experience = self._make_experience(batch)
133
            self._on_make_experience_end(experience)
134
        self._on_send_start()
135
        if self.buffer_cpu_offload:
136
            experience.to_device("cpu")
137
        self._send_items(experience)
138
        self._on_send_end()
139
        self._on_batch_end()
140

141
    def workingloop(self, dataloader_fn: Callable[[], Iterable], num_epochs: int = 1, num_steps: int = 0):
142
        """Working loop of the experience maker.
143

144
        Args:
145
            dataloader_fn (Callable[[], Iterable]): A function that returns a dataloader.
146
            num_epochs (int, optional): Iterate the dataloader for number of epochs. Defaults to 1.
147
            num_steps (int, optional): Iterate the dataloader for number if steps. If this value > 0, num_epochs will be ignored. Defaults to 0.
148
        """
149
        self._get_ready()
        self._on_loop_start()
        dataloader = dataloader_fn()
        if num_steps > 0:
            # ignore num epochs
            it = iter(dataloader)
            for _ in tqdm(range(num_steps), desc="ExperienceMaker", disable=not is_rank_0()):
                try:
                    batch = next(it)
                except StopIteration:
                    it = iter(dataloader)
                    batch = next(it)
                self._inference_step(batch)
        else:
            with tqdm(total=num_epochs * len(dataloader), desc="ExperienceMaker", disable=not is_rank_0()) as pbar:
                for _ in range(num_epochs):
                    for batch in dataloader:
                        self._inference_step(batch)
                        pbar.update()
        self._on_loop_end()

    @ray.method(concurrency_group="model_io")
    def update_experience_maker(
        self,
        new_actor_state_dict: Dict[str, Any] = None,
        new_actor_lora_config_dict: Dict[str, Any] = None,
        new_critic_state_dict: Dict[str, Any] = None,
        new_critic_lora_config_dict: Dict[str, Any] = None,
        fully_update: bool = False,
        chunk_start: bool = None,
        chunk_end: bool = None,
    ):
        """
182
        called by trainer
183
        chunk_start: Set True at the first call. Before sending state_dict calls
184
        chunk_end: Set True at the last call. After sending state_dict calls.
185
        fully_update: Set True if you want to sync models when initializing
186

187
        TODO: load_state_dict integrate with model-sharding strategy
188
        """
189
        _watch_memory = self._debug
        if chunk_start:
            if self._debug:
                print("[maker] UPDATE ")
            if _watch_memory:
                tracemalloc.start()
            self._model_visit_lock.acquire()

        with torch.no_grad():
            if new_actor_state_dict is not None:
                if not self._update_lora_weights or fully_update:
                    self.experience_maker.actor.model.load_state_dict(new_actor_state_dict, strict=False)
                else:
                    new_actor_state_dict = state_dict_to(new_actor_state_dict, device=torch.cuda.current_device())
                    state_dict_increase = self.actor_lora_constructor.reconstruct_increase(
                        new_actor_state_dict, new_actor_lora_config_dict
                    )
                    self.actor_lora_constructor.load_state_dict_increase(
                        self.experience_maker.actor.model, state_dict_increase
                    )
            if new_critic_state_dict is not None:
                if not self._update_lora_weights or fully_update:
                    self.experience_maker.critic.load_state_dict(new_critic_state_dict, strict=False)
                else:
                    new_critic_state_dict = state_dict_to(new_critic_state_dict, device=torch.cuda.current_device())
                    state_dict_increase = self.critic_lora_constructor.reconstruct_increase(
                        new_critic_state_dict, new_critic_lora_config_dict
                    )
                    self.critic_lora_constructor.load_state_dict_increase(
                        self.experience_maker.critic, state_dict_increase
                    )

        # the lock must be released only after both the actor and the critic have been updated
        if chunk_end:
            self._model_visit_lock.release()
            if _watch_memory:
                current, peak = tracemalloc.get_traced_memory()
                print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
                tracemalloc.stop()
            if fully_update:
                self._is_fully_initialized = True

    def _on_make_experience_start(self) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_loop_start(self) -> None:
        for callback in self.callbacks:
            callback.on_loop_start()

    def _on_loop_end(self) -> None:
        for callback in self.callbacks:
            callback.on_loop_end()

    def _on_send_start(self) -> None:
        for callback in self.callbacks:
            callback.on_send_start()

    def _on_send_end(self) -> None:
        for callback in self.callbacks:
            callback.on_send_end()

    def _on_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_start()

    def _on_batch_end(self) -> None:
        for callback in self.callbacks:
            callback.on_batch_end()

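# Illustrative only: a minimal, hypothetical sketch of how a trainer process might drive
# the chunked weight-sync protocol documented in update_experience_maker() above. The
# chunking helper and all names here are assumptions for illustration; the real trainers
# in this package implement their own sync logic (e.g. sync_models_to_remote_makers()).
def _example_sync_actor_weights(maker_handles: List[ray.actor.ActorHandle], actor_state_dict: Dict[str, Any]) -> None:
    def _chunk_state_dict(state_dict: Dict[str, Any], chunk_size: int = 64):
        # split a flat state dict into smaller dicts so each remote call stays small
        keys = list(state_dict.keys())
        for i in range(0, len(keys), chunk_size):
            yield {k: state_dict[k] for k in keys[i : i + chunk_size]}

    chunks = list(_chunk_state_dict(actor_state_dict))
    refs = []
    for i, chunk in enumerate(chunks):
        for maker in maker_handles:
            refs.append(
                maker.update_experience_maker.remote(
                    new_actor_state_dict=chunk,
                    fully_update=True,  # full (non-LoRA) update; also marks the maker as initialized
                    chunk_start=(i == 0),  # first call acquires the maker's model-visit lock
                    chunk_end=(i == len(chunks) - 1),  # last call releases it
                )
            )
    ray.get(refs)  # wait until every maker has applied the new weights
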
def _set_default_generate_kwargs(generate_kwargs: dict, actor: Actor) -> dict:
    origin_model = actor.model
    new_kwargs = {**generate_kwargs}
    # use the HuggingFace model's generation helpers directly
    if "prepare_inputs_fn" not in generate_kwargs and hasattr(origin_model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = origin_model.prepare_inputs_for_generation

    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(origin_model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = origin_model._update_model_kwargs_for_generation

    return new_kwargs
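

# Illustrative only: a minimal, hypothetical sketch of how a launcher script might create
# a holder actor and run its working loop. The factory arguments, actor names, and
# namespace are placeholders, not APIs provided by this module; the detached trainers
# named below must already be registered as Ray actors, and the maker's process must have
# RAY_NAMESPACE set (see _init_target_trainer_list) by whatever launch setup is used.
def _example_launch_maker(
    strategy_fn: Callable[[], Strategy],
    model_fn: Callable[[], Tuple[Actor, Critic, RewardModel, Actor]],
    dataloader_fn: Callable[[], Iterable],
) -> None:
    ray.init(namespace="coati")  # placeholder namespace
    maker = ExperienceMakerHolder.options(name="maker0").remote(
        detached_trainer_name_list=["trainer0", "trainer1"],  # placeholder trainer actor names
        strategy_fn=strategy_fn,
        model_fn=model_fn,
        buffer_cpu_offload=True,
        kl_coef=0.1,
        eval_performance=True,
    )
    # run 100 inference steps; experiences are split and streamed round-robin to the trainers' buffers
    ray.get(maker.workingloop.remote(dataloader_fn, num_steps=100))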