paddlenlp

Форк
0
219 строк · 9.2 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import warnings
16

17
import paddle
18
from paddle import _C_ops
19
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
20
    device_guard,
21
)
22
from paddle.base import core, framework
23
from paddle.base.framework import Variable
24
from paddle.optimizer import Adam, AdamW, Momentum
25
from ppfleetx.distributed.apis import env
26
from ppfleetx.utils.tensor_fusion_helper import fused_parameters
27

28
# Public API of this module: the stock Paddle optimizers re-exported as-is,
# followed by the fused / CPU-offloaded AdamW variants defined below.
__all__ = ["Adam", "AdamW", "Momentum", "FusedAdamW", "FusedOffloadAdamW"]
35

36

37
class FusedAdamW(paddle.optimizer.AdamW):
    """AdamW that can optionally optimize fused parameter tensors.

    When ``tensor_fusion=True`` is passed in ``config``, the parameters are
    grouped by ``fused_parameters`` and the resulting fused tensors are what
    is handed to :class:`paddle.optimizer.AdamW` instead of the originals.

    Weight decay is applied to:
      * with tensor fusion: the tensors returned as the first ("decay") group
        by ``fused_parameters``;
      * without tensor fusion: every parameter whose name contains none of
        ``"bias"``, ``"norm"`` or ``"b_0"``.

    Args:
        learning_rate: Forwarded to ``paddle.optimizer.AdamW``.
        parameters: Iterable of parameters to optimize.
        grad_clip: Forwarded to ``paddle.optimizer.AdamW``.
        **config: Remaining ``AdamW`` keyword arguments. The extra key
            ``tensor_fusion`` (default ``False``) is consumed here and not
            forwarded.
    """

    def __init__(self, learning_rate, parameters, grad_clip, **config):
        tensor_fusion = config.pop("tensor_fusion", False)
        # ``parameters`` is iterated more than once below (decay selection and
        # the super() call); materialize it so a one-shot iterator is not
        # silently exhausted after the first pass.
        parameters = list(parameters)

        # Default to "no sharding" so that a single-process run with tensor
        # fusion enabled does not reference an unbound ``sharding_size``.
        sharding_size = 1
        if paddle.distributed.get_world_size() > 1:
            sharding_size = env.get_hcg().get_sharding_parallel_world_size()

        if tensor_fusion:
            self.decay_fused_tensors, self.all_fused_tensors = fused_parameters(parameters, sharding_size > 1)
            decay_params = {p.name for p in self.decay_fused_tensors}
            optim_params = self.all_fused_tensors
        else:
            decay_params = {p.name for p in parameters if not any(nd in p.name for nd in ["bias", "norm", "b_0"])}
            optim_params = parameters

        # A set gives O(1) membership checks; the callback is invoked per parameter.
        def apply_decay_param_fun(name):
            return name in decay_params

        super().__init__(
            learning_rate=learning_rate,
            parameters=optim_params,
            grad_clip=grad_clip,
            apply_decay_param_fun=apply_decay_param_fun,
            **config,
        )
60

61

62
class FusedOffloadAdamW(paddle.optimizer.AdamW):
    """AdamW that keeps its optimizer state on the CPU.

    Parameter selection and weight-decay grouping work as in ``FusedAdamW``
    (optional ``tensor_fusion`` key in ``config``). In addition:

    * when ``multi_precision`` is enabled, fp32 master weights for fp16/bf16
      parameters are created on the CPU in ``__init__``;
    * moment and beta-power accumulators are created with ``device="cpu"``;
    * each update copies the parameter and gradient to the CPU as fp32, runs
      ``_C_ops.adamw_`` there, then casts the result back to the parameter's
      dtype, moves it to accelerator ``self._dev_id`` and shares that buffer
      into the parameter.

    Args:
        learning_rate: Forwarded to ``paddle.optimizer.AdamW``.
        parameters: Iterable of parameters to optimize.
        grad_clip: Forwarded to ``paddle.optimizer.AdamW``.
        **config: Remaining ``AdamW`` keyword arguments. The extra key
            ``tensor_fusion`` (default ``False``) is consumed here and not
            forwarded.
    """

    def __init__(self, learning_rate, parameters, grad_clip, **config):
        tensor_fusion = config.pop("tensor_fusion", False)
        # ``parameters`` is iterated more than once below; materialize it so a
        # one-shot iterator is not silently exhausted after the first pass.
        parameters = list(parameters)

        # Default to "no sharding" so that a single-process run with tensor
        # fusion enabled does not reference an unbound ``sharding_size``.
        sharding_size = 1
        if paddle.distributed.get_world_size() > 1:
            sharding_size = env.get_hcg().get_sharding_parallel_world_size()

        if tensor_fusion:
            self.decay_fused_tensors, self.all_fused_tensors = fused_parameters(parameters, sharding_size > 1)
            decay_params = {p.name for p in self.decay_fused_tensors}
            optim_params = self.all_fused_tensors
        else:
            decay_params = {p.name for p in parameters if not any(nd in p.name for nd in ["bias", "norm", "b_0"])}
            optim_params = parameters

        # A set gives O(1) membership checks; the callback is invoked per parameter.
        def apply_decay_param_fun(name):
            return name in decay_params

        super().__init__(
            learning_rate=learning_rate,
            parameters=optim_params,
            grad_clip=grad_clip,
            apply_decay_param_fun=apply_decay_param_fun,
            **config,
        )

        # Names of parameters whose CPU accumulators have already been
        # created; consulted by ``_create_accumulators`` to skip them.
        self._already_create_accumulater = set()
        device = paddle.get_device()
        # Accelerator index the updated parameters are copied back to
        # (e.g. "gpu:3" -> 3); 0 when the current device is "cpu".
        self._dev_id = 0 if device == "cpu" else int(device.split(":")[1])

        # Create fp32 CPU master weights for the tensors the optimizer
        # actually updates (the fused buffers when tensor fusion is on, not
        # the original parameters, whose copies would never be used).
        # NOTE(review): presumably the base ``_create_master_weight`` reuses
        # these entries (looked up by ``p.name``) — confirm for the Paddle
        # version in use.
        # NOTE(review): for bf16 parameters, ``p.numpy()`` may not yield bf16
        # values (Paddle has exposed bf16 to NumPy as uint16); confirm the
        # cast below yields the intended fp32 weights.
        for p in optim_params:
            if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                self._master_weights[p.name] = core.eager.Tensor(
                    name=p.name + "_fp32_master",
                    value=p.numpy(),
                    place=core.CPUPlace(),
                    stop_gradient=True,
                ).cast(paddle.float32)

    def _add_moments_pows(self, p):
        """Create the CPU moment and beta-power accumulators for ``p``.

        fp16/bf16 parameters get fp32 accumulators. The beta-power
        accumulators are shape ``[1]`` and are initialised to ``beta1``/``beta2``
        (or 0.9/0.999 when the betas are ``Variable`` objects).
        """
        acc_dtype = p.dtype
        if self._is_dtype_fp16_or_bf16(acc_dtype):
            acc_dtype = core.VarDesc.VarType.FP32
        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype, device="cpu")
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype, device="cpu")
        self._add_accumulator(
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9 if isinstance(self._beta1, Variable) else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device="cpu",
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999 if isinstance(self._beta2, Variable) else self._beta2,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device="cpu",
        )

    def _create_accumulators(self, block, parameters):
        """Create CPU accumulators for every parameter that lacks them.

        With ``multi_precision`` enabled, fp16/bf16 parameters have their
        accumulators attached to the master weight returned by
        ``_create_master_weight``. Otherwise accumulators are attached to the
        parameter itself, and a warning is emitted for fp16/bf16 parameters.
        """
        with device_guard():
            assert isinstance(block, framework.Block)
            if isinstance(parameters, dict):
                parameters = self._update_param_group(parameters)

            # Create accumulator tensors for first and second moments
            for p in parameters:
                if p.name in self._already_create_accumulater:
                    continue
                if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype):
                    master_p = self._create_master_weight(p)
                    self._add_moments_pows(master_p)
                    self._already_create_accumulater.add(p.name)
                    continue
                if self._is_dtype_fp16_or_bf16(p.dtype) and not self._multi_precision:
                    warnings.warn(
                        "Accumulating with FP16 or BF16 in optimizer can lead to poor accuracy or slow convergence. "
                        "Consider using multi_precision=True option of the Adam optimizer."
                    )
                self._add_moments_pows(p)
                self._already_create_accumulater.add(p.name)

    def _get_accumulator_master(self, name, param):
        """Fetch accumulator ``name`` for ``param``, resolving master weights.

        When ``multi_precision`` is enabled and ``param`` is fp16/bf16, the
        lookup is keyed by the master weight's name instead of the parameter's.

        Args:
            name: Accumulator name (prefixed with ``self._name`` if set).
            param: Parameter whose accumulator is requested.

        Returns:
            The accumulator tensor.

        Raises:
            Exception: If no such accumulator exists for the resolved tensor.
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param.dtype)
        target_param = self._master_weights[param.name] if find_master else param
        target_name = target_param.name
        if name not in self._accumulators or target_name not in self._accumulators[name]:
            raise Exception("Accumulator {} does not exist for parameter {}".format(name, target_name))
        return self._accumulators[name][target_name]

    def _append_optimize_op(self, block, param_and_grad):
        """Run one AdamW step for a single parameter on the CPU.

        The parameter and gradient are copied to the CPU as fp32 and updated
        by ``_C_ops.adamw_`` together with the CPU-resident moments, beta
        powers and (optional) master weight. The updated fp32 parameter is
        then moved to accelerator ``self._dev_id``, cast back to the original
        dtype, and shared into the parameter's buffer.

        Only dygraph (eager) mode is handled; outside it no update is
        performed. Always returns ``None``.
        """
        with device_guard():
            assert isinstance(block, framework.Block)
            if isinstance(param_and_grad, dict):
                param_and_grad = self._update_param_group(param_and_grad)
            param, grad = param_and_grad

            # Parameters rejected by ``apply_decay_param_fun`` skip weight decay.
            with_decay = self._apply_decay_param_fun is None or bool(self._apply_decay_param_fun(param.name))

            moment1 = self._get_accumulator_master(self._moment1_acc_str, param)
            moment2 = self._get_accumulator_master(self._moment2_acc_str, param)
            beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param)
            beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param)
            find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param.dtype)
            master_weight = self._master_weights[param.name] if find_master else None
            lr = self._create_param_lr(param_and_grad)

            # create the adamw optimize op
            if framework.in_dygraph_mode():
                lr_ratio_ = 1.0 if self._lr_ratio is None else self._lr_ratio(param)

                _beta1 = self._beta1 if not isinstance(self._beta1, Variable) else self._beta1.item(0)
                _beta2 = self._beta2 if not isinstance(self._beta2, Variable) else self._beta2.item(0)

                origin_dtype = param.dtype
                cpu_fp32_param = param.cpu().cast(paddle.float32)
                cpu_fp32_grad = grad.cpu().cast(paddle.float32)

                # In-place update. NOTE(review): the updates to the moments,
                # beta powers and master weight persist only if ``.cpu()`` on
                # these already-CPU tensors returns them without copying —
                # confirm.
                # NOTE(review): the positional arguments presumably map to
                # skip_update (None), min_row_size_to_use_multithread (1000),
                # multi_precision (find_master) and use_global_beta_pow (False).
                # Verify against the ``adamw_`` signature in the Paddle version
                # in use.
                _C_ops.adamw_(
                    cpu_fp32_param,
                    cpu_fp32_grad,
                    lr.cpu(),
                    moment1.cpu(),
                    moment2.cpu(),
                    beta1_pow_acc.cpu(),
                    beta2_pow_acc.cpu(),
                    master_weight.cpu() if master_weight is not None else None,
                    None,
                    _beta1,
                    _beta2,
                    self._epsilon,
                    lr_ratio_,
                    self._weight_decay,
                    with_decay,
                    self._lazy_mode,
                    1000,
                    find_master,
                    False,
                )

                # Replace the parameter's storage with the updated values,
                # moved to the accelerator and cast back to the original dtype.
                param._clear_data()
                cpu_fp32_param.cuda(self._dev_id).cast(origin_dtype)._share_buffer_to(param)

                return None
220

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.