# colossalai
#
# Fork
# 0
# 153 lines · 5.6 KB
1
import dataclasses
2
import math
3
import warnings
4
from typing import Optional
5

6
import loralib as lora
7
import torch
8
import torch.nn as nn
9
import torch.nn.functional as F
10

11

12
@dataclasses.dataclass
class LoRAManager:
    """Global switches shared by every ``LoraLinear`` layer.

    Attributes:
        merge_weights: When True, ``LoraLinear.train(False)`` (i.e. ``.eval()``) folds the
            scaled LoRA product into the frozen base weight and deletes the LoRA factors;
            calling ``train(True)`` on a layer that is already merged then raises
            ``NotImplementedError``. Defaults to False.
    """

    merge_weights: bool = dataclasses.field(default=False)


# Module-level singleton consulted by LoraLinear.train().
LORA_MANAGER = LoRAManager()
18

19

20
class LoraLinear(lora.LoRALayer, nn.Module):
    """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.

    The layer reuses the given ``weight``/``bias`` Parameter objects (no copy). When
    ``r > 0`` it adds trainable low-rank factors ``lora_A`` of shape ``(r, in_features)``
    and ``lora_B`` of shape ``(out_features, r)``, freezes ``weight``, and computes::

        F.linear(x, W, bias) + (lora_dropout(x) @ lora_A.T @ lora_B.T) * (lora_alpha / r)

    where ``W`` is ``weight`` (transposed back when ``fan_in_fan_out`` is set). If
    ``LORA_MANAGER.merge_weights`` is True, switching the layer to eval mode merges the
    LoRA product into ``weight`` and deletes ``lora_A``/``lora_B``.
    """

    def __init__(
        self,
        weight: nn.Parameter,
        bias: Optional[nn.Parameter],
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        fan_in_fan_out: bool = False,
    ):
        nn.Module.__init__(self)
        # Per-layer merging is disabled; merging is driven by the global LORA_MANAGER in train().
        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
        self.weight = weight
        self.bias = bias

        # NOTE(review): the shape is unpacked as (out, in) *before* the optional transpose
        # below, i.e. `weight` is read in nn.Linear layout even when fan_in_fan_out=True.
        # That seems to contradict the parameter comment above — confirm before relying on it.
        out_features, in_features = weight.shape
        self.in_features = in_features
        self.out_features = out_features

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freeze the pre-trained weight matrix (the Parameter is shared with the replaced layer).
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            # Store the weight transposed; _maybe_transpose() undoes this in forward/merge.
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """(Re)initialize the LoRA factors, if present.

        ``lora_A`` gets nn.Linear's default Kaiming-uniform init and ``lora_B`` is zeroed,
        so the LoRA branch contributes nothing until ``lora_B`` is trained.
        """
        if hasattr(self, "lora_A"):
            # Initialize A with the default values for nn.Linear and set B to zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _maybe_transpose(self, w: torch.Tensor) -> torch.Tensor:
        """Return ``w.T`` when the weight is stored transposed (``fan_in_fan_out``), else ``w``."""
        return w.T if self.fan_in_fan_out else w

    def train(self, mode: bool = True):
        """Set train/eval mode and, if enabled globally, merge LoRA weights on eval.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self

        Raises:
            NotImplementedError: if ``LORA_MANAGER.merge_weights`` is True and the layer
                is switched back to training after its weights were merged.
        """
        # nn.Module.train sets self.training AND propagates `mode` to submodules such as
        # lora_dropout (an nn.Dropout when lora_dropout > 0). Setting only self.training
        # would leave that dropout active after .eval().
        nn.Module.train(self, mode)
        if LORA_MANAGER.merge_weights:
            if mode and self.merged:
                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
                raise NotImplementedError("LoRA unmerge is not tested.")
            elif not mode and not self.merged:
                warnings.warn("Invoke module.eval() would merge LoRA weights.")
                self._merge_lora_weights()
        return self

    def _merge_lora_weights(self) -> None:
        """Fold ``scaling * lora_B @ lora_A`` into ``weight``, drop the factors, and mark merged."""
        if self.r > 0:
            self.weight.data += self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
            delattr(self, "lora_A")
            delattr(self, "lora_B")
        self.merged = True

    def _unmerge_lora_weights(self) -> None:
        """Reverse of ``_merge_lora_weights``.

        Not called anywhere: ``train()`` raises before unmerging because this path is
        untested. It is kept so it can be wired back in once it has been verified.
        """
        # Make sure that the weights are not merged
        if self.r > 0:
            if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
                # FIXME(csric): temporary fix
                self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
                self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
                self.reset_parameters()
            else:
                self.weight.data -= self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
        self.merged = False

    def forward(self, x: torch.Tensor):
        """Apply the frozen linear map, plus the scaled LoRA branch while it is unmerged."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
        return result
101

102

103
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
    """Build a ``LoraLinear`` of rank ``lora_rank`` that reuses ``linear``'s parameters.

    The returned layer holds the very same ``weight``/``bias`` Parameter objects as
    ``linear``; they are not copied.

    Raises:
        AssertionError: if ``lora_rank`` exceeds ``linear.in_features``
            (the check is an ``assert``, so it is skipped under ``python -O``).
    """
    fan_in = linear.in_features
    assert lora_rank <= fan_in, f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
    return LoraLinear(linear.weight, linear.bias, r=lora_rank)
109

110

111
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
    """Swap every ``nn.Linear`` below ``module`` for a ``LoraLinear``, in place.

    Non-Linear children are descended into recursively; ``module`` itself is never
    replaced, only its children.
    """
    for child_name, sub in module.named_children():
        if not isinstance(sub, nn.Linear):
            _convert_to_lora_recursively(sub, lora_rank)
            continue
        setattr(module, child_name, _lora_linear_wrapper(sub, lora_rank))
117

118

119
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
    """Convert a torch.nn.Module to a LoRA module, in place.

    Every ``nn.Linear`` in ``module``'s submodule tree (not ``module`` itself) is replaced
    by a ``LoraLinear`` of rank ``lora_rank`` that shares the original weight and bias.
    Afterwards loralib's ``mark_only_lora_as_trainable`` is applied to the whole module.

    Args:
        module (nn.Module): The module to convert; it is modified in place.
        lora_rank (int): LoRA rank. Values <= 0 disable LoRA and leave ``module`` untouched.
        lora_train_bias (str, optional): Bias-training policy forwarded to
            ``mark_only_lora_as_trainable`` ('none', 'all' or 'lora_only').
            Defaults to 'none'.

    Returns:
        nn.Module: The same ``module`` object that was passed in.
    """
    if lora_rank > 0:
        _convert_to_lora_recursively(module, lora_rank)
        lora.mark_only_lora_as_trainable(module, lora_train_bias)
    return module
134

135

136
class LoRAModule(nn.Module):
    """Base class for modules whose ``nn.Linear`` layers can be swapped for ``LoraLinear``.

    Subclasses must call ``self.convert_to_lora()`` as the last step of their
    ``__init__()``, after every ``nn.Linear`` has been created, so that the
    conversion sees all of them.

    Args:
        lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
        lora_train_bias (str, optional): Which biases remain trainable:
            'none' trains no biases, 'all' trains every bias, and 'lora_only' trains
            only the biases of LoRA layers. Defaults to 'none'.
    """

    def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
        super().__init__()
        # Stored only; nothing is converted until convert_to_lora() is called.
        self.lora_train_bias = lora_train_bias
        self.lora_rank = lora_rank

    def convert_to_lora(self) -> None:
        """Replace this module's ``nn.Linear`` layers with ``LoraLinear`` in place."""
        convert_to_lora_module(self, lora_rank=self.lora_rank, lora_train_bias=self.lora_train_bias)
154

# Cookie usage
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to JSC "SberTech" processing your personal data in order to improve our website and the GitVerse service, and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.