colossalai

Форк
0
137 строк · 5.2 Кб
1
from abc import ABC, abstractmethod
2
from contextlib import nullcontext
3
from typing import Callable, Dict, List, Optional, Tuple, Union
4

5
import torch
6
import torch.nn as nn
7
from coati.experience_buffer import ExperienceBuffer
8
from torch.optim import Optimizer
9
from torch.utils.data import DataLoader
10
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
11

12
from colossalai.booster import Booster
13
from colossalai.booster.plugin import Plugin
14

15
from .sampler import DistributedSampler
16

17
# Accepted argument shapes for `Strategy.prepare`: a bare model, a
# (model, optimizer) pair, or a kwargs dict forwarded to `Booster.boost`.
_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]
18

19

20
class Strategy(ABC):
    """Base class for training strategies.

    A strategy owns a ColossalAI ``Booster`` (optionally configured with a
    ``Plugin``) and defines how models, optimizers and dataloaders are
    prepared for (distributed) training, plus checkpoint save/load.

    Args:
        plugin_initializer: zero-argument callable returning the ``Plugin``
            to construct the ``Booster`` with, or ``None`` for no plugin
            (the default).
    """

    def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
        super().__init__()
        # NOTE: dist must be initialized before Booster
        self.setup_distributed()
        self.plugin = plugin_initializer()
        self.booster = Booster(plugin=self.plugin)
        self._post_init()

    @abstractmethod
    def _post_init(self) -> None:
        """Hook run at the end of ``__init__`` for subclass-specific setup/validation."""

    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
        """Run the backward pass for ``loss`` through the booster.

        ``model`` and ``kwargs`` are part of the interface for subclasses;
        the base implementation does not use them.
        """
        self.booster.backward(loss, optimizer)

    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
        """Apply one optimizer update step (``kwargs`` unused here)."""
        optimizer.step()

    @abstractmethod
    def setup_distributed(self) -> None:
        """Initialize the distributed environment (must run before Booster creation)."""

    @abstractmethod
    def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
        """Build a ``DataLoader`` over the given experience buffer."""

    def model_init_context(self):
        """Return the context manager active during model construction (no-op by default)."""
        return nullcontext()

    def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
        """Prepare [model | (model, optimizer) | Dict] based on each strategy.
        NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.

        Example::
            >>> # e.g., include lr_scheduler
            >>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
            >>> # when fine-tuning actor and critic
            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
            >>> # or when training reward model
            >>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
            >>> # or just inference
            >>> actor, critic = strategy.prepare(actor, critic)

        Raises:
            RuntimeError: if a tuple argument is not a 2-tuple, or an
                argument is neither a module, a tuple, nor a dict.

        Returns:
            Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
        """
        rets = []
        for arg in boost_args:
            if isinstance(arg, nn.Module):
                model, *_ = self.booster.boost(arg)
                rets.append(model)
            elif isinstance(arg, tuple):
                try:
                    model, optimizer = arg
                except ValueError as exc:
                    # chain the original unpacking error for easier debugging
                    raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"') from exc
                model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
                rets.append((model, optimizer))
            # FIX: test against the builtin `dict`, not the deprecated
            # `typing.Dict` alias (behavior is identical).
            elif isinstance(arg, dict):
                model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
                boost_result = dict(
                    model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    dataloader=dataloader,
                    lr_scheduler=lr_scheduler,
                )
                # remove None values so callers only see what they passed in
                boost_result = {key: value for key, value in boost_result.items() if value is not None}
                rets.append(boost_result)
            else:
                raise RuntimeError(f"Type {type(arg)} is not supported")

        # single argument -> unwrap for caller convenience
        return rets[0] if len(rets) == 1 else rets

    @staticmethod
    def unwrap_model(model: nn.Module) -> nn.Module:
        """Get the unwrapped model from a wrapped model made by Strategy.prepare.

        The base implementation is the identity; subclasses that wrap the
        model override this.

        Args:
            model (nn.Module): the model to unwrap

        Returns:
            nn.Module: the original model
        """
        return model

    def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
        """Save model weights via the booster (optionally sharded)."""
        self.booster.save_model(model, path, shard=shard, **kwargs)

    def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
        """Load model weights from ``path`` via the booster."""
        self.booster.load_model(model, path, strict)

    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
        """Save optimizer state; ``only_rank0=False`` writes sharded checkpoints."""
        self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)

    def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
        """Load optimizer state from ``path`` via the booster."""
        self.booster.load_optimizer(optimizer, path)

    def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
        return DistributedSampler(dataset, 1, 0)

    @abstractmethod
    def save_pretrained(
        self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None
    ) -> None:
        """Save model (and optionally tokenizer) in HuggingFace ``save_pretrained`` format."""

    @abstractmethod
    def get_model_state_dict_shard(self, model: nn.Module, **config):
        """Yield shards of the model's state dict (subclass-defined sharding)."""
138

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.