import os
import random
from collections import OrderedDict
from typing import Callable, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.experience_buffer import ExperienceBuffer
from coati.models import Actor, Critic, RewardModel
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel

from .base import Strategy
from .sampler import DistributedSampler


# TODO: move this to a util module (moving it into ray.util introduces a circular import)
def get_grad_required_state_dict(model: nn.Module):
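    """Collect a state dict containing only the parameters that require gradients.

    Tensors are detached, so the returned snapshot carries no autograd history.
    """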
25
    state_dict = OrderedDict()
26
    for name, parameter in model.named_parameters():
27
        if parameter.requires_grad:
28
            state_dict[name] = parameter.detach()
29
    return state_dict
30

31

32
class DDPStrategy(Strategy):
33
    """
34
    Strategy for distributed training using torch.distributed.
35
    """

    def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
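        """
        Args:
            seed: seed applied to the Python, NumPy and PyTorch RNGs in setup_distributed.
            plugin_initializer: callable used by the base Strategy to construct the booster
                plugin; defaults to TorchDDPPlugin.
        """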
        self.seed = seed
        super().__init__(plugin_initializer)

    def _try_init_dist(self, force: bool = False) -> None:
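        """Try to initialize the default process group from launcher environment variables.

        Reads RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT (as set by torchrun),
        creates an NCCL process group and binds this process to its local CUDA device.
        Failures are silently ignored unless force is True.
        """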
        try:
            rank = int(os.environ["RANK"])
            local_rank = int(os.environ["LOCAL_RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            host = os.environ["MASTER_ADDR"]
            port = int(os.environ["MASTER_PORT"])
            dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
            torch.cuda.set_device(local_rank)
        except KeyError as e:
            if force:
                raise RuntimeError(
                    f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
                )
        except Exception as e:
            if force:
                raise e

    def _post_init(self) -> None:
        assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."

    def setup_distributed(self) -> None:
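        """Initialize distributed training (raising if the launcher environment is missing) and seed all RNGs."""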
        self._try_init_dist(force=True)
        self.set_seed(self.seed)

    def set_seed(self, seed: int) -> None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
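        """Build a DataLoader over the experience buffer through the TorchDDP plugin.

        Batch size and collation come from the buffer itself; the plugin takes care of
        distributed sampling so that each rank draws its own shard of experiences.
        """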
        return self.plugin.prepare_dataloader(
            data_buffer,
            batch_size=data_buffer.sample_batch_size,
            shuffle=True,
            drop_last=True,
            pin_memory=pin_memory,
            collate_fn=data_buffer.collate_fn,
        )

    def setup_sampler(self, dataset) -> DistributedSampler:
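        """Create a DistributedSampler over the dataset for the current world size and rank."""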
        # FIXME(cwher): this is only invoked in train_on_ray and has not been tested since adapting to the Booster API.
        return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

    def unwrap_model(self, model: nn.Module) -> nn.Module:
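        """Return the module wrapped inside a TorchDDPModel."""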
        assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
        return model.unwrap()

    def save_pretrained(
        self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None
    ) -> None:
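        """Save the model in Hugging Face format.

        Rank 0 writes the config (and the tokenizer, if given) via save_pretrained with a
        no-op save_function; every rank then calls save_model to write pytorch_model.bin,
        and rank 0 finally strips the leading "model." prefix from the saved state-dict keys.
        """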
        if dist.get_rank() == 0:
            unwrapped_model = self.unwrap_model(model)
            assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
            pretrained_model = unwrapped_model.model
            assert isinstance(pretrained_model, PreTrainedModel)
            # HACK: use hf save_pretrained only to write the config; weights are saved separately below
            pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
            if tokenizer is not None:
                tokenizer.save_pretrained(path)

        model_path = os.path.join(path, "pytorch_model.bin")
        self.save_model(model, model_path, shard=shard)

        def _replace_keys(model_path: str, replace_fn: Callable):
            state_dict = torch.load(model_path, map_location="cpu")
            state_dict = {replace_fn(k): v for k, v in state_dict.items()}
            torch.save(state_dict, model_path)

        # FIXME: save_model adds a "model." prefix to the keys in pytorch_model.bin
        # HACK: rename keys of pytorch_model.bin
        if dist.get_rank() == 0:
            _replace_keys(model_path, lambda k: k.replace("model.", "", 1))

    def get_model_state_dict_shard(self, model: nn.Module, **config):
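        """Yield the model state dict, optionally filtered and split into shards.

        Recognized config keys:
            requires_grad_only: if truthy, include only parameters that require gradients.
            shard_size: approximate shard size in bytes; a shard is yielded whenever the
                accumulated tensor bytes reach this threshold.
        """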
        # TODO: implement sharding on naive strategy
        model = self.unwrap_model(model)
        if config.get("requires_grad_only", False):
            state_dict = get_grad_required_state_dict(model)
        else:
            state_dict = model.state_dict()

        if "shard_size" in config:
            shard_size = config["shard_size"]
            accumulate_size = 0
            state_dict_shard = OrderedDict()
            for name, param in state_dict.items():
                state_dict_shard[name] = param
                accumulate_size += param.numel() * param.element_size()
                if accumulate_size >= shard_size:
                    accumulate_size = 0
                    yield state_dict_shard
                    state_dict_shard = OrderedDict()
            if accumulate_size > 0:
                yield state_dict_shard
        else:
            yield state_dict
