# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
from collections import namedtuple

import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto

from paddlenlp.ops import Topology
from paddlenlp.trainer.trainer_utils import _get_distributed_seeds
from ppfleetx.utils.log import logger

_mesh = None

# Module-level seed bookkeeping populated by set_seed().
_seed = None
_dp_seed = None


def get_mesh():
    """Return the global Mesh, lazily creating a trivial single-card mesh
    when no mesh has been set and the world size is 1."""
    global _mesh
    if _mesh is None and paddle.distributed.get_world_size() == 1:
        set_mesh(Mesh(get_local_rank(), 1, 1, 1))
    return _mesh


def set_mesh(mesh):
    global _mesh
    _mesh = mesh


class Mesh:
    """Describes the global [pp, dp, mp] process mesh used by auto parallel."""

    def __init__(self, rank, dp_degree, mp_degree, pp_degree):
        self._dp_dim = "dp" if dp_degree > 1 else None
        self._mp_dim = "mp" if mp_degree > 1 else None
        self._dp_degree = dp_degree
        self._mp_degree = mp_degree
        self._pp_degree = pp_degree

        # Lay the ranks out as [dp, pp, mp], then transpose to [pp, dp, mp]
        # so that the leading mesh axis is the pipeline dimension.
        arr = np.arange(0, pp_degree * dp_degree * mp_degree).reshape([dp_degree, pp_degree, mp_degree])
        arr = arr.transpose(1, 0, 2)
        self.world_process_mesh = auto.ProcessMesh(arr, dim_names=["pp", "dp", "mp"])
        self.g_process_mesh = auto.ProcessMesh(list(range(pp_degree * dp_degree * mp_degree)))
        # Locate this rank's pipeline/data/model-parallel coordinates.
        ipp, idp, imp = np.where(arr == rank)
        ipp = ipp[0]
        idp = idp[0]
        imp = imp[0]

        # Keep the full 3-D mesh when both dp and mp are parallelized; otherwise
        # drop the dp axis (dp_degree == 1) or the mp axis (mp_degree == 1)
        # at this rank's coordinate.
        if dp_degree > 1 and mp_degree > 1:
            self.pp_process_mesh = self.world_process_mesh
        elif mp_degree > 1:
            self.pp_process_mesh = self.world_process_mesh[:, idp, :]
        else:
            self.pp_process_mesh = self.world_process_mesh[:, :, imp]
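    # Illustrative layout of the mesh built in __init__ above (not executed):
    # with dp_degree=2, mp_degree=2, pp_degree=2 the eight ranks become the
    # [pp, dp, mp] mesh
    #     [[[0, 1], [4, 5]],
    #      [[2, 3], [6, 7]]]
    # i.e. pipeline stage 0 owns ranks {0, 1, 4, 5}, stage 1 owns {2, 3, 6, 7},
    # and each innermost pair such as [0, 1] is one mp group.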

    @property
    def dp_degree(self):
        return self._dp_degree

    @property
    def mp_degree(self):
        return self._mp_degree

    # TODO(JZ-LIANG) Support SP as an independent mesh axis
    @property
    def sp_degree(self):
        return self._mp_degree

    @property
    def pp_degree(self):
        return self._pp_degree

    @property
    def dp_dim(self):
        return self._dp_dim

    @property
    def mp_dim(self):
        return self._mp_dim

    # TODO(JZ-LIANG) Support SP as an independent mesh axis
    @property
    def sp_dim(self):
        return self._mp_dim

    def __getitem__(self, idx):
        return self.pp_process_mesh[idx]


def init_dist_env(config):
    """Set the device, build the global Mesh from the configured parallel
    degrees and initialize the collective fleet environment."""
    paddle.set_device(config.Global.device)

    mesh = Mesh(
        get_local_rank(),
        config.Distributed.dp_degree,
        config.Distributed.mp_degree,
        config.Distributed.pp_degree,
    )
    set_mesh(mesh)
    paddle.distributed.fleet.init(is_collective=True)
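# For reference: init_dist_env only relies on the attribute paths read above;
# the values below are purely illustrative, and the real config object comes
# from the ppfleetx configuration system. The product of the three degrees is
# expected to match the number of launched processes.
#   config.Global.device          e.g. "gpu"
#   config.Distributed.dp_degree  e.g. 2
#   config.Distributed.mp_degree  e.g. 2
#   config.Distributed.pp_degree  e.g. 2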


def get_local_rank():
    # PADDLE_RANK_IN_NODE is exported by the paddle distributed launcher;
    # it defaults to 0 for single-process runs.
    return int(os.getenv("PADDLE_RANK_IN_NODE", 0))


def set_seed(seed):
    """Seed paddle, random and numpy with seeds derived from the distributed
    topology described by the global Mesh."""
    topo = None
    if dist.get_world_size() > 1:
        topo = Topology(
            dist.get_rank(),
            dist.get_world_size(),
            dp_degree=_mesh.dp_degree,
            pp_degree=_mesh.pp_degree,
            mp_degree=_mesh.mp_degree,
            sharding_degree=1,  # auto_parallel's sharding is not orthogonal with dp, mp and pp
        )

    global_seed, local_seed, random_seed = _get_distributed_seeds(seed, topo)

    # NOTE: add (1024 + world_size) to seed for CI cases
    global_seed = global_seed + 1024 + paddle.distributed.get_world_size()
    local_seed = local_seed + 1024 + paddle.distributed.get_world_size()

    paddle.seed(global_seed)
    random.seed(random_seed)
    np.random.seed(random_seed)

    logger.info("The global seed is set to {}, local seed is set to {} and "
                "random seed is set to {}.".format(global_seed, local_seed, random_seed))

    global _seed
    global _dp_seed
    _seed = seed
    _dp_seed = global_seed
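# A minimal usage sketch under assumed conditions: the SimpleNamespace below is
# a stand-in for the real ppfleetx config and only fills in the fields that
# init_dist_env reads; the device and seed values are illustrative.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        Global=SimpleNamespace(device="gpu"),
        Distributed=SimpleNamespace(dp_degree=1, mp_degree=1, pp_degree=1),
    )
    init_dist_env(demo_config)
    set_seed(2023)
    logger.info("world process mesh: {}".format(get_mesh().world_process_mesh))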
