colossalai
153 lines · 5.6 KB
1import dataclasses
2import math
3import warnings
4from typing import Optional
5
6import loralib as lora
7import torch
8import torch.nn as nn
9import torch.nn.functional as F
10
11
@dataclasses.dataclass
class LoRAManager:
    """Process-wide LoRA switches consulted by ``LoraLinear``.

    Attributes:
        merge_weights: When True, ``LoraLinear.train(False)`` folds the low-rank
            update into the base weight and deletes the LoRA factors. Defaults to False.
    """

    merge_weights: bool = dataclasses.field(default=False)


# Single shared instance; LoraLinear.train() reads LORA_MANAGER.merge_weights.
LORA_MANAGER = LoRAManager()
18
19
class LoraLinear(lora.LoRALayer, nn.Module):
    """Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.

    The layer reuses the given ``weight`` and ``bias`` Parameter objects (they are
    assigned, not copied). When ``r > 0`` it adds trainable low-rank factors:
    ``lora_A`` has shape (r, in_features) and ``lora_B`` has shape (out_features, r).
    While the weights are not merged, ``forward`` computes::

        linear(x, W, bias) + (lora_dropout(x) @ lora_A.T @ lora_B.T) * (lora_alpha / r)

    Args:
        weight (nn.Parameter): Pre-trained weight. Its shape is unpacked as
            (out_features, in_features). It is frozen (``requires_grad = False``)
            when ``r > 0``.
        bias (Optional[nn.Parameter]): Pre-trained bias, passed to ``F.linear`` as-is.
        r (int): LoRA rank. 0 disables the low-rank branch entirely.
        lora_alpha (int): Numerator of the low-rank scaling factor ``lora_alpha / r``.
        lora_dropout (float): Dropout probability for the input of the low-rank
            branch. It is forwarded to ``lora.LoRALayer``, which builds ``self.lora_dropout``.
        fan_in_fan_out (bool): If True, the weight data is transposed once at the end
            of construction, and transposed back every time it is used.
    """

    def __init__(
        self,
        weight: nn.Parameter,
        bias: Optional[nn.Parameter],
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        fan_in_fan_out: bool = False,
    ):
        # nn.Module.__init__ must run before any Parameter/submodule is assigned
        # (torch raises otherwise), so it goes before LoRALayer.__init__.
        nn.Module.__init__(self)
        # merge_weights is fixed to False here. Merging is driven by the global
        # LORA_MANAGER in train() instead.
        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
        self.weight = weight
        self.bias = bias

        # NOTE(review): the shape is read before the optional transpose below. The
        # incoming weight is therefore always treated as (out_features, in_features),
        # even when fan_in_fan_out=True, which the parameter comment describes as
        # (fan_in, fan_out). Confirm which layout callers actually pass.
        out_features, in_features = weight.shape
        self.in_features = in_features
        self.out_features = out_features

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            # This rebinds the data of the shared Parameter object, so whoever else
            # holds `weight` also sees the transposed tensor.
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Re-initialise the LoRA factors, if present.

        A uses the same Kaiming-uniform init as nn.Linear. B is zeroed, so the
        low-rank update starts at exactly zero.
        """
        if hasattr(self, "lora_A"):
            # Initialize A with the default values for nn.Linear and set B to zero.
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def _maybe_transpose(self, w: torch.Tensor) -> torch.Tensor:
        """Return ``w.T`` when ``fan_in_fan_out`` is set, otherwise ``w`` unchanged."""
        return w.T if self.fan_in_fan_out else w

    def train(self, mode: bool = True):
        """Set the training mode of this layer and of all its submodules.

        When ``LORA_MANAGER.merge_weights`` is True, switching to eval mode folds
        ``lora_B @ lora_A * scaling`` into the base weight and deletes the LoRA
        factors. Switching a merged layer back to train mode is not supported.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self

        Raises:
            NotImplementedError: If training mode is requested after the weights
                were merged while ``LORA_MANAGER.merge_weights`` is enabled.
        """
        # Use nn.Module.train so the mode also reaches submodules such as the LoRA
        # dropout. Setting only `self.training` left dropout active during eval.
        nn.Module.train(self, mode)
        if LORA_MANAGER.merge_weights:
            if mode and self.merged:
                warnings.warn("Invoke module.train() would unmerge LoRA weights.")
                # The LoRA factors are deleted when merging, so the update cannot be
                # subtracted back out. Unmerging is therefore unsupported.
                raise NotImplementedError("LoRA unmerge is not tested.")
            if not mode and not self.merged:
                warnings.warn("Invoke module.eval() would merge LoRA weights.")
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += self._maybe_transpose(self.lora_B @ self.lora_A) * self.scaling
                    # Once folded into the weight, the factors are no longer used by forward().
                    delattr(self, "lora_A")
                    delattr(self, "lora_B")
                self.merged = True

        return self

    def forward(self, x: torch.Tensor):
        """Apply the base linear map plus, while unmerged, the scaled low-rank update."""
        result = F.linear(x, self._maybe_transpose(self.weight), bias=self.bias)
        if self.r > 0 and not self.merged:
            # Out-of-place add (not +=) on purpose; see the class docstring.
            result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
        return result
101
102
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
    """Build a LoraLinear that shares ``linear``'s weight and bias.

    Args:
        linear: The layer to wrap. Its Parameter objects are reused, not copied.
        lora_rank: Rank of the LoRA factors. Must not exceed ``linear.in_features``.

    Returns:
        A new LoraLinear with rank ``lora_rank``.

    Raises:
        AssertionError: If ``lora_rank`` exceeds ``linear.in_features``. This check
            only runs when assertions are enabled.
    """
    assert lora_rank <= linear.in_features, (
        f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
    )
    return LoraLinear(weight=linear.weight, bias=linear.bias, r=lora_rank)
109
110
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
    """Replace, in place, every nn.Linear found below ``module`` with a LoraLinear.

    Children are visited in registration order, depth first. A Linear child is
    swapped directly; any other child is recursed into. ``module`` itself is never
    replaced.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, nn.Linear):
            _convert_to_lora_recursively(submodule, lora_rank)
            continue
        setattr(module, child_name, _lora_linear_wrapper(submodule, lora_rank))
117
118
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
    """Convert a torch.nn.Module to a LoRA module, in place.

    Every nn.Linear below ``module`` is replaced by a LoraLinear. Then only the
    LoRA parameters (plus the biases selected by ``lora_train_bias``) are left
    trainable.

    Args:
        module (nn.Module): The module to convert. It is modified in place.
        lora_rank (int): LoRA rank. A value <= 0 leaves ``module`` untouched.
        lora_train_bias (str): Which biases stay trainable. Forwarded to
            ``lora.mark_only_lora_as_trainable``: 'none', 'all' or 'lora_only'.

    Returns:
        nn.Module: The same ``module`` object that was passed in.
    """
    if lora_rank > 0:
        _convert_to_lora_recursively(module, lora_rank)
        lora.mark_only_lora_as_trainable(module, lora_train_bias)
    return module
134
135
class LoRAModule(nn.Module):
    """Base class for modules whose nn.Linear layers get swapped for LoraLinear layers.

    Subclasses should call ``convert_to_lora()`` as the final step of their
    ``__init__()``, once all of their nn.Linear submodules have been created.

    Args:
        lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
        lora_train_bias (str, optional): Which biases stay trainable:
            'none' trains no biases, 'all' trains every bias, and 'lora_only'
            trains only the biases of LoRA layers. Defaults to 'none'.
    """

    def __init__(self, lora_rank: int = 0, lora_train_bias: str = "none") -> None:
        super(LoRAModule, self).__init__()
        # Stored so that convert_to_lora() can be called later, after subclass setup.
        self.lora_rank = lora_rank
        self.lora_train_bias = lora_train_bias

    def convert_to_lora(self) -> None:
        """Convert this module's nn.Linear layers in place, using the stored LoRA settings."""
        convert_to_lora_module(self, lora_rank=self.lora_rank, lora_train_bias=self.lora_train_bias)
154