paddlenlp

Форк
0
156 строк · 4.9 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import distutils.util
16
import importlib
17
import os
18

19
import paddle
20
from paddle import _C_ops
21

22
# Handles to Paddle's stock implementations, captured at import time so that
# later monkey-patching by mock_layers() cannot affect them.
# FusedLinearWithGradAdd.forward calls origin_linear; if it instead looked up
# paddle.incubate.nn.functional.fused_linear after patching, it would call
# itself. The custom norm layers below subclass OriginLayerNorm.
origin_linear = paddle.incubate.nn.functional.fused_linear
OriginLayerNorm = paddle.nn.LayerNorm
24

25

26
def try_import(module_name, func_name=None):
    """Import ``module_name`` and return the module, or ``None`` if that fails.

    Used in this file to probe optional extensions without making them hard
    dependencies.

    Args:
        module_name: absolute module name passed to
            ``importlib.import_module``.
        func_name: ignored; kept only so existing call sites keep working.
            The module object itself is returned, not an attribute of it.

    Returns:
        The imported module, or ``None`` if importing raised ``ImportError``
        (which includes ``ModuleNotFoundError``).
    """
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # NOTE: this also hides ImportErrors raised *while* importing an
        # installed module (e.g. one of its own dependencies is missing),
        # not only the case where the module is absent.
        return None
35

36

37
# Optional extensions: each *_lib name is the imported module, or None when
# the extension could not be imported.
fast_ln_lib = try_import("fast_ln")
if fast_ln_lib is not None:
    # Kernel used by FastLayerNorm.
    fast_ln = fast_ln_lib.fast_ln

fused_ln_lib = try_import("fused_ln")
if fused_ln_lib is not None:
    # Kernels used by FusedLayerNorm and FusedRMSNorm respectively.
    fused_ln = fused_ln_lib.fused_ln
    fused_rms_norm = fused_ln_lib.fused_rms_norm
# When an extension is missing, its kernel names above are left undefined.
46

47

48
def check_normalized_shape(normalized_shape):
    """Ensure ``normalized_shape`` describes a single normalized dimension.

    The custom norm layers in this module only support normalizing over one
    dimension. An ``int`` shape is always accepted; a list/tuple must have
    exactly one entry.

    Raises:
        ValueError: if ``normalized_shape`` is a list/tuple whose length is
            not 1.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently disable this validation.
    if isinstance(normalized_shape, (list, tuple)) and len(normalized_shape) != 1:
        raise ValueError(
            "normalized_shape must have exactly one dimension, got %r"
            % (normalized_shape,))
51

52

53
class FusedLayerNorm(OriginLayerNorm):
    """``paddle.nn.LayerNorm`` whose forward runs the ``fused_ln`` kernel.

    Parameters (``weight``, ``bias``) are created by the stock
    ``paddle.nn.LayerNorm`` constructor; only ``forward`` is replaced.
    Only a single normalized dimension is supported.

    Raises:
        ImportError: on construction, if the ``fused_ln`` extension could not
            be imported.
        ValueError: on construction, if ``normalized_shape`` spans more than
            one dimension.
    """

    def __init__(self,
                 normalized_shape,
                 epsilon=1e-05,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        # Fail fast with a clear message; otherwise the first forward() call
        # would hit a NameError because `fused_ln` is never bound.
        if fused_ln_lib is None:
            raise ImportError(
                "FusedLayerNorm requires the `fused_ln` extension, which "
                "could not be imported.")
        super().__init__(
            normalized_shape=normalized_shape,
            epsilon=epsilon,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            # `name` used to be accepted but silently dropped.
            name=name)
        check_normalized_shape(self._normalized_shape)

    def forward(self, input):
        # The kernel returns a sequence; only its first element is used as
        # the layer output.
        return fused_ln(input, self.weight, self.bias, self._epsilon)[0]
69

70

71
class FusedRMSNorm(OriginLayerNorm):
    """RMS normalization computed by the ``fused_rms_norm`` kernel.

    Subclasses ``paddle.nn.LayerNorm`` to reuse its parameter creation, so it
    can be installed in place of ``paddle.nn.LayerNorm``. It has a scale
    parameter (``weight``) and no bias: ``bias_attr=False`` is always passed
    to the parent constructor.

    NOTE(review): unlike ``paddle.nn.LayerNorm``, this constructor takes no
    ``bias_attr`` argument. Code that passes ``bias_attr=...`` to a patched
    ``paddle.nn.LayerNorm`` will get a ``TypeError``.

    Raises:
        ImportError: on construction, if the ``fused_ln`` extension (which
            provides ``fused_rms_norm``) could not be imported.
        ValueError: on construction, if ``normalized_shape`` spans more than
            one dimension.
    """

    def __init__(self,
                 normalized_shape,
                 epsilon=1e-05,
                 weight_attr=None,
                 name=None):
        # Fail fast with a clear message; otherwise the first forward() call
        # would hit a NameError because `fused_rms_norm` is never bound.
        if fused_ln_lib is None:
            raise ImportError(
                "FusedRMSNorm requires the `fused_ln` extension, which "
                "could not be imported.")
        super().__init__(
            normalized_shape=normalized_shape,
            epsilon=epsilon,
            weight_attr=weight_attr,
            bias_attr=False,
            # `name` used to be accepted but silently dropped.
            name=name)
        check_normalized_shape(self._normalized_shape)

    def forward(self, input):
        # The kernel returns a sequence; only its first element is used as
        # the layer output.
        return fused_rms_norm(input, self.weight, self._epsilon)[0]
86

87

88
class FastLayerNorm(OriginLayerNorm):
    """``paddle.nn.LayerNorm`` whose forward runs the ``fast_ln`` kernel.

    Parameters (``weight``, ``bias``) are created by the stock
    ``paddle.nn.LayerNorm`` constructor; only ``forward`` is replaced.
    Only a single normalized dimension is supported.

    Raises:
        ImportError: on construction, if the ``fast_ln`` extension could not
            be imported.
        ValueError: on construction, if ``normalized_shape`` spans more than
            one dimension.
    """

    def __init__(self,
                 normalized_shape,
                 epsilon=1e-05,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        # Fail fast with a clear message; otherwise the first forward() call
        # would hit a NameError because `fast_ln` is never bound.
        if fast_ln_lib is None:
            raise ImportError(
                "FastLayerNorm requires the `fast_ln` extension, which "
                "could not be imported.")
        super().__init__(
            normalized_shape=normalized_shape,
            epsilon=epsilon,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            # `name` used to be accepted but silently dropped.
            name=name)
        check_normalized_shape(self._normalized_shape)

    def forward(self, input):
        # The kernel returns a sequence; only its first element is used as
        # the layer output.
        return fast_ln(input, self.weight, self.bias, self._epsilon)[0]
104

105

106
class FusedLinearWithGradAdd(paddle.autograd.PyLayer):
    """Linear layer whose backward uses ``fused_linear_param_grad_add``.

    The forward pass is the stock ``fused_linear``. In backward, the input
    gradient is computed with a plain matmul. The parameter gradients come
    from ``_C_ops.fused_linear_param_grad_add``:

    * When the parameters carry a ``main_grad`` attribute (for a bias, both
      weight and bias must have it), the op accumulates into those buffers.
      ``None`` is returned for the parameters.
    * Otherwise fresh gradients are computed and returned.

    The trailing boolean passed to the op is ``True`` on the ``main_grad``
    path and ``False`` otherwise. It is presumably the multi-precision
    flag; confirm against the Paddle op definition.
    """

    @staticmethod
    def forward(ctx, x, weight, bias=None, name=None):
        out = origin_linear(x, weight, bias)
        ctx.save_for_backward(x, weight, bias)
        return out

    @staticmethod
    def backward(ctx, y_grad):
        x, weight, bias = ctx.saved_tensor()
        # dL/dx = dL/dy @ W^T. This assumes Paddle's [in, out] weight layout.
        x_grad = paddle.matmul(y_grad, weight, transpose_y=True)

        if bias is None:
            # Only x and weight were tensor inputs: return exactly two grads.
            if not hasattr(weight, "main_grad"):
                weight_grad, _ = _C_ops.fused_linear_param_grad_add(
                    x, y_grad, None, None, False)
                return x_grad, weight_grad
            weight.main_grad, _ = _C_ops.fused_linear_param_grad_add(
                x, y_grad, weight.main_grad, None, True)
            return x_grad, None

        # x, weight and bias were tensor inputs: return three grads.
        accumulate_in_place = (hasattr(weight, "main_grad")
                               and hasattr(bias, "main_grad"))
        if not accumulate_in_place:
            weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(
                x, y_grad, None, None, False)
            return x_grad, weight_grad, bias_grad
        weight.main_grad, bias.main_grad = _C_ops.fused_linear_param_grad_add(
            x, y_grad, weight.main_grad, bias.main_grad, True)
        return x_grad, None, None
136

137

138
def strtobool(s):
    """Convert a truth-value string to ``True`` or ``False``.

    The matching is case-insensitive and uses the same spellings as the
    former ``distutils.util.strtobool``:

    * ``y``, ``yes``, ``t``, ``true``, ``on`` and ``1`` mean true.
    * ``n``, ``no``, ``f``, ``false``, ``off`` and ``0`` mean false.

    Raises:
        ValueError: if ``s`` is not one of the recognised spellings.
    """
    # Re-implemented locally: distutils is deprecated (PEP 632) and was
    # removed from the standard library in Python 3.12.
    value = s.lower()
    if value in ("y", "yes", "t", "true", "on", "1"):
        return True
    if value in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (value,))
140

141

142
def get_env(env_name, default_value=False):
    """Read the environment variable ``env_name`` as a boolean flag.

    Args:
        env_name: name of the environment variable.
        default_value: used when the variable is unset or blank. It is
            converted with ``str`` and parsed like the variable itself, so
            ``False``, ``"0"`` and ``"false"`` all mean false.

    Returns:
        ``True`` or ``False``.

    Raises:
        ValueError: if the variable (or the default) is not a recognised
            truth-value spelling.
    """
    raw = os.getenv(env_name)
    # An exported-but-empty variable (``FOO=``) used to reach strtobool("")
    # and raise; treat it like an unset variable instead.
    if raw is None or not raw.strip():
        raw = str(default_value)
    return strtobool(raw.strip())
144

145

146
def mock_layers():
    """Monkey-patch Paddle layers according to environment flags.

    At most one replacement for ``paddle.nn.LayerNorm`` is installed. Flags
    are checked in priority order: ``USE_FAST_LN``, then ``USE_FUSED_LN``,
    then ``USE_FUSED_RMS_NORM``. The first true flag wins and later flags
    are not read.

    ``USE_LINEAR_WITH_GRAD_ADD`` is checked independently. When true, both
    ``paddle.nn.functional.linear`` and
    ``paddle.incubate.nn.functional.fused_linear`` are pointed at
    ``FusedLinearWithGradAdd.apply``.
    """
    norm_overrides = (
        ("USE_FAST_LN", FastLayerNorm),
        ("USE_FUSED_LN", FusedLayerNorm),
        ("USE_FUSED_RMS_NORM", FusedRMSNorm),
    )
    for flag, replacement in norm_overrides:
        if get_env(flag):
            paddle.nn.LayerNorm = replacement
            break

    if not get_env("USE_LINEAR_WITH_GRAD_ADD"):
        return
    paddle.nn.functional.linear = FusedLinearWithGradAdd.apply
    paddle.incubate.nn.functional.fused_linear = FusedLinearWithGradAdd.apply
157

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.