colossalai

Форк
0
209 строк · 9.0 Кб
1
import warnings
2
from typing import Optional
3

4
import torch.nn as nn
5

6
import colossalai
7
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
8
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
9
from colossalai.zero.gemini.gemini_ddp import GeminiDDP
10

11
from .ddp import DDPStrategy
12

13

14
class LowLevelZeroStrategy(DDPStrategy):
    """
    The strategy for training with ColossalAI's low-level ZeRO plugin (ZeRO stage 1 or 2).

    Args:
        stage(int): The stage to use in ZeRO. Choose in (1, 2).
        precision(str): The precision to use. Choose in ('fp32', 'fp16').
        seed(int): The seed for the random number generator.
        placement_policy(str): Choose in ('cpu', 'cuda').
            "cpu" enables the plugin's ``cpu_offload`` option; "cuda" leaves offloading disabled.
        reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
            It is rounded up to a whole number of M (2**20) before being handed to the plugin
            as ``reduce_bucket_size_in_m``.
        overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
        initial_scale(float): The initial scale for the optimizer.
        growth_factor(float): The growth factor for the optimizer.
        backoff_factor(float): The backoff factor for the optimizer.
        growth_interval(int): The growth interval for the optimizer.
        hysteresis(int): The hysteresis for the optimizer.
        min_scale(float): The minimum scale for the optimizer.
        max_scale(float): The maximum scale for the optimizer.
        max_norm(float): The maximum norm for the optimizer.
        norm_type(float): The norm type for the optimizer.

    """

    def __init__(
        self,
        stage: int = 2,
        precision: str = "fp16",
        seed: int = 42,
        placement_policy: str = "cuda",
        reduce_bucket_size: int = 12 * 1024**2,  # only for stage 1&2
        overlap_communication: bool = True,  # only for stage 1&2
        initial_scale: float = 2**16,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        min_scale: float = 1,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
    ) -> None:
        assert stage in (1, 2), f'Unsupported stage "{stage}"'
        assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
        assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'

        # `reduce_bucket_size` is expressed in bytes, whereas the plugin's
        # `reduce_bucket_size_in_m` parameter is expressed in units of M (2**20).
        # Forwarding the byte count unchanged would inflate the bucket by a
        # factor of 2**20, so convert here. Ceiling division keeps any small
        # positive request from collapsing to zero.
        bytes_per_m = 1024**2
        reduce_bucket_size_in_m = (reduce_bucket_size + bytes_per_m - 1) // bytes_per_m

        def plugin_initializer() -> LowLevelZeroPlugin:
            # Built lazily: the parent constructor decides when to create the plugin.
            return LowLevelZeroPlugin(
                stage=stage,
                precision=precision,
                reduce_bucket_size_in_m=reduce_bucket_size_in_m,
                overlap_communication=overlap_communication,
                cpu_offload=(placement_policy == "cpu"),
                initial_scale=initial_scale,
                growth_factor=growth_factor,
                backoff_factor=backoff_factor,
                growth_interval=growth_interval,
                hysteresis=hysteresis,
                min_scale=min_scale,
                max_scale=max_scale,
                max_norm=max_norm,
                norm_type=norm_type,
            )

        super().__init__(seed, plugin_initializer)

    def _post_init(self) -> None:
        """Check that the plugin created by the initializer is a ``LowLevelZeroPlugin``."""
        assert isinstance(
            self.plugin, LowLevelZeroPlugin
        ), f"{type(self).__name__}'s plugin is not initialized properly."

    def setup_distributed(self) -> None:
        """Launch ColossalAI from the torch launcher environment, seeded with ``self.seed``."""
        colossalai.launch_from_torch({}, seed=self.seed)

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Return the ``module`` held by a ``LowLevelZeroModel`` wrapper."""
        assert isinstance(model, LowLevelZeroModel)
        return model.module

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        """
        Yield the shards produced by ``model.state_dict_shard``.

        The call always uses ``max_shard_size=1024`` and ``only_rank_0=False``;
        ``**config`` is accepted but currently not consulted.
        """
        assert isinstance(model, LowLevelZeroModel)
        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)
95

96

97
class GeminiStrategy(DDPStrategy):
    """
    The strategy for training with ColossalAI's Gemini plugin (ZeRO-3). Precision is always fp16.

    Args:
        seed(int): The seed for the random number generator.
        shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
            This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
            The value is also forwarded to the plugin as ``strict_ddp_mode``.
        placement_policy(str): The placement policy for gemini, forwarded unchanged to ``GeminiPlugin``.
            Defaults to "auto".
            NOTE(review): the accepted values depend on the installed ColossalAI version - confirm
            against the ``GeminiPlugin`` documentation.
        shard_param_frac(float): Forwarded to ``GeminiPlugin``. Only for static placement.
        offload_optim_frac(float): Forwarded to ``GeminiPlugin``. Only for static placement.
        offload_param_frac(float): Forwarded to ``GeminiPlugin``. Only for static placement.
        pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
        force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
        search_range_m(int): The number of search range for the chunk size, divided by 2^20. Only for ZeRO-3.
        hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
        min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
        gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
        initial_scale(float): The initial scale for the optimizer.
        growth_factor(float): The growth factor for the optimizer.
        backoff_factor(float): The backoff factor for the optimizer.
        growth_interval(int): The growth interval for the optimizer.
        hysteresis(int): The hysteresis for the optimizer.
        min_scale(float): The minimum scale for the optimizer.
        max_scale(float): The maximum scale for the optimizer.
        max_norm(float): The maximum norm for the optimizer.
        norm_type(float): The norm type for the optimizer.

    """

    def __init__(
        self,
        seed: int = 42,
        shard_init: bool = False,  # only for stage 3
        placement_policy: str = "auto",
        shard_param_frac: float = 1.0,  # only for static placement
        offload_optim_frac: float = 0.0,  # only for static placement
        offload_param_frac: float = 0.0,  # only for static placement
        pin_memory: bool = True,  # only for stage 3
        force_outputs_fp32: bool = False,  # only for stage 3
        search_range_m: int = 32,  # only for stage 3
        hidden_dim: Optional[int] = None,  # only for stage 3
        min_chunk_size_m: float = 32,  # only for stage 3
        gpu_margin_mem_ratio: float = 0.0,  # only for stage 3
        initial_scale: float = 2**16,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        min_scale: float = 1,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
    ) -> None:
        # TODO(ver217): support shard_init when using from_pretrained()
        if shard_init:
            warnings.warn(
                "Shard init is not supported model.from_pretrained() yet. "
                "Please load weights after strategy.prepare()"
            )
        self.shard_init = shard_init

        warnings.warn("Stage 3 only supports fp16. Precision is set to fp16.")

        def plugin_initializer() -> GeminiPlugin:
            # NOTE: dist should be initialized before calling get_current_device(),
            # so the device is looked up here - when the parent constructor invokes
            # the initializer - instead of eagerly in __init__.
            #
            # colossalai has changed api for get_current_device in 0.3.4 version or newer;
            # only a failed import triggers the fallback, other errors propagate.
            try:
                from colossalai.accelerator import get_accelerator
            except ImportError:
                from colossalai.utils import get_current_device

                chunk_init_device = get_current_device()
            else:
                chunk_init_device = get_accelerator().get_current_device()

            return GeminiPlugin(
                chunk_init_device=chunk_init_device,
                placement_policy=placement_policy,
                shard_param_frac=shard_param_frac,
                offload_optim_frac=offload_optim_frac,
                offload_param_frac=offload_param_frac,
                precision="fp16",
                pin_memory=pin_memory,
                force_outputs_fp32=force_outputs_fp32,
                strict_ddp_mode=shard_init,
                search_range_m=search_range_m,
                hidden_dim=hidden_dim,
                min_chunk_size_m=min_chunk_size_m,
                gpu_margin_mem_ratio=gpu_margin_mem_ratio,
                initial_scale=initial_scale,
                growth_factor=growth_factor,
                backoff_factor=backoff_factor,
                growth_interval=growth_interval,
                hysteresis=hysteresis,
                min_scale=min_scale,
                max_scale=max_scale,
                max_norm=max_norm,
                norm_type=norm_type,
            )

        super().__init__(seed, plugin_initializer)

    def _post_init(self) -> None:
        """Check that the plugin created by the initializer is a ``GeminiPlugin``."""
        assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."

    def setup_distributed(self) -> None:
        """Launch ColossalAI from the torch launcher environment, seeded with ``self.seed``."""
        colossalai.launch_from_torch({}, seed=self.seed)

    def model_init_context(self):
        """Return the parent strategy's model-initialization context unchanged."""
        return super().model_init_context()

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Return the ``module`` held by a ``GeminiDDP`` wrapper."""
        assert isinstance(model, GeminiDDP)
        return model.module
210

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.