"""Adapter utils."""

import json
import logging
import os
from abc import abstractmethod
from typing import Callable, Dict

import torch
import torch.nn.functional as F
from torch import Tensor, nn

logger = logging.getLogger(__name__)


class BaseAdapter(nn.Module):
    """Base adapter.

    Can be subclassed to implement custom adapters.
    To implement a custom adapter, subclass this class and implement the
    following methods:
        - get_config_dict
        - forward

    """

    @abstractmethod
    def get_config_dict(self) -> Dict:
        """Get config dict."""

    @abstractmethod
    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass."""

    def save(self, output_path: str) -> None:
        """Save model."""
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut)
        torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

    @classmethod
    def load(cls, input_path: str) -> "BaseAdapter":
        """Load model."""
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)
        model = cls(**config)
        model.load_state_dict(
            torch.load(
                os.path.join(input_path, "pytorch_model.bin"),
                map_location=torch.device("cpu"),
            )
        )
        return model


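# --- Example (illustrative; not part of the original file) ---
# A minimal sketch of a custom adapter, assuming only the BaseAdapter contract
# above: implement forward and get_config_dict, and save/load work unchanged
# because load() rebuilds the model via cls(**config). The class name and its
# behavior are hypothetical.
class ScaleAdapter(BaseAdapter):
    """Hypothetical adapter that rescales embeddings by a learned scalar."""

    def __init__(self, in_features: int) -> None:
        super().__init__()
        self.in_features = in_features
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass: elementwise rescaling; shape is preserved."""
        return self.scale * embed

    def get_config_dict(self) -> Dict:
        """Keys mirror __init__ kwargs so load() can reconstruct the model."""
        return {"in_features": self.in_features}

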
class LinearLayer(BaseAdapter):
    """Linear transformation.

    Args:
        in_features (int): Input dimension.
        out_features (int): Output dimension.
        bias (bool): Whether to use bias. Defaults to False.

    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # seed with the identity matrix and zero bias so the adapter starts as
        # a no-op; nn.Linear stores weight as (out_features, in_features), and
        # the seed is an exact identity only when the matrix is square
        self.linear.weight.data.copy_(torch.eye(out_features, in_features))
        if bias:
            self.linear.bias.data.copy_(torch.zeros(out_features))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass (Wv)."""
        return self.linear(embed)

    def get_config_dict(self) -> Dict:
        """Get config dict."""
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "bias": self.bias,
        }


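# --- Example (illustrative; not part of the original file) ---
# A small sketch showing that a freshly constructed square LinearLayer is an
# identity map, so fine-tuning starts from "no change" to the embeddings.
def _demo_linear_layer() -> None:
    adapter = LinearLayer(in_features=4, out_features=4)
    embed = torch.randn(2, 4)
    # identity-seeded weights: output equals input before any training
    assert torch.allclose(adapter(embed), embed)

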
def get_activation_function(name: str) -> Callable:
    """Get activation function.

    Args:
        name (str): Name of activation function.

    """
    activations: Dict[str, Callable] = {
        "relu": F.relu,
        "sigmoid": torch.sigmoid,
        "tanh": torch.tanh,
        "leaky_relu": F.leaky_relu,
        # add more activations here as needed
    }
    if name not in activations:
        raise ValueError(f"Unknown activation function: {name}")
    return activations[name]


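# --- Example (illustrative; not part of the original file) ---
# The lookup returns the plain torch callable, so it can be applied directly.
def _demo_activation_lookup() -> None:
    act = get_activation_function("relu")
    assert torch.equal(act(torch.tensor([-1.0, 2.0])), torch.tensor([0.0, 2.0]))

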
class TwoLayerNN(BaseAdapter):
    """Two-layer transformation.

    Args:
        in_features (int): Input dimension.
        hidden_features (int): Hidden dimension.
        out_features (int): Output dimension.
        bias (bool): Whether to use bias. Defaults to False.
        activation_fn_str (str): Name of activation function. Defaults to "relu".
        add_residual (bool): Whether to add a learned residual connection.
            Defaults to False.

    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        bias: bool = False,
        activation_fn_str: str = "relu",
        add_residual: bool = False,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.bias = bias
        self.activation_fn_str = activation_fn_str

        self.linear1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.linear2 = nn.Linear(hidden_features, out_features, bias=bias)

        self._activation_function = get_activation_function(activation_fn_str)
        self._add_residual = add_residual
        # residual weight, initialized to 0 so that when add_residual is True
        # the adapter starts out as the identity map on the input embedding
        self.residual_weight = nn.Parameter(torch.zeros(1))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass: two-layer MLP with an optional weighted residual.

        Args:
            embed (Tensor): Input tensor.

        """
        output1 = self.linear1(embed)
        output1 = self._activation_function(output1)
        output2 = self.linear2(output1)

        if self._add_residual:
            output2 = self.residual_weight * output2 + embed

        return output2

    def get_config_dict(self) -> Dict:
        """Get config dict."""
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "bias": self.bias,
            "activation_fn_str": self.activation_fn_str,
            "add_residual": self._add_residual,
        }
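

# --- Example (illustrative; not part of the original file) ---
# A sketch of the save/load round trip defined on BaseAdapter, assuming a
# writable temporary directory: get_config_dict feeds config.json, and load()
# rebuilds the adapter from that config before restoring the weights.
def _demo_two_layer_roundtrip() -> None:
    import tempfile

    adapter = TwoLayerNN(
        in_features=8, hidden_features=16, out_features=8, add_residual=True
    )
    embed = torch.randn(3, 8)
    with tempfile.TemporaryDirectory() as path:
        adapter.save(path)
        restored = TwoLayerNN.load(path)
    # the restored adapter reproduces the original outputs exactly
    assert torch.allclose(adapter(embed), restored(embed))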
