1"""Adapter utils."""
2
3import json
4import logging
5import os
6from abc import abstractmethod
7from typing import Callable, Dict
8
9import torch
10import torch.nn.functional as F
11from torch import Tensor, nn
12
13logger = logging.getLogger(__name__)
14
15
class BaseAdapter(nn.Module):
    """Base adapter.

    Can be subclassed to implement custom adapters.
    To implement a custom adapter, subclass this class and implement the
    following methods:
    - get_config_dict
    - forward

    """

    @abstractmethod
    def get_config_dict(self) -> Dict:
        """Get config dict."""

    @abstractmethod
    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass."""

    def save(self, output_path: str) -> None:
        """Save model config and weights to `output_path`."""
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut)
        torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

    @classmethod
    def load(cls, input_path: str) -> "BaseAdapter":
        """Load model config and weights from `input_path`."""
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)
        # Reconstruct the adapter from its saved config, then restore weights.
        model = cls(**config)
        model.load_state_dict(
            torch.load(
                os.path.join(input_path, "pytorch_model.bin"),
                map_location=torch.device("cpu"),
            )
        )
        return model

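# Example (hypothetical usage sketch, not part of this module): the subclass
# below shows the minimal contract a custom adapter must satisfy -- implement
# `get_config_dict` and `forward` -- plus a save/load round trip. The class
# name, dimension (256), and path ("/tmp/identity_adapter") are illustrative
# assumptions.
#
#     class IdentityAdapter(BaseAdapter):
#         def __init__(self, dim: int) -> None:
#             super().__init__()
#             self.dim = dim
#
#         def get_config_dict(self) -> Dict:
#             # Keys must match `__init__` kwargs, since `load` calls
#             # `cls(**config)`.
#             return {"dim": self.dim}
#
#         def forward(self, embed: Tensor) -> Tensor:
#             return embed
#
#     adapter = IdentityAdapter(dim=256)
#     adapter.save("/tmp/identity_adapter")
#     restored = IdentityAdapter.load("/tmp/identity_adapter")
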
class LinearLayer(BaseAdapter):
    """Linear transformation.

    Args:
        in_features (int): Input dimension.
        out_features (int): Output dimension.
        bias (bool): Whether to use bias. Defaults to False.

    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        # Seed with the identity matrix and a zero bias so the adapter starts
        # as a no-op. `nn.Linear.weight` has shape (out_features, in_features);
        # for non-square shapes this yields a truncated identity.
        self.linear.weight.data.copy_(torch.eye(out_features, in_features))
        if bias:
            self.linear.bias.data.copy_(torch.zeros(out_features))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass (Wv)."""
        return self.linear(embed)

    def get_config_dict(self) -> Dict:
        """Get config dict."""
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "bias": self.bias,
        }

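# Example (illustrative; the 384 dimension is an arbitrary assumption): a
# square LinearLayer is the identity at initialization, so fine-tuning starts
# from a no-op transform.
#
#     adapter = LinearLayer(384, 384)
#     embed = torch.randn(2, 384)
#     assert torch.allclose(adapter(embed), embed)  # identity at init
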
def get_activation_function(name: str) -> Callable:
    """Get activation function.

    Args:
        name (str): Name of activation function.

    """
    activations: Dict[str, Callable] = {
        "relu": F.relu,
        "sigmoid": torch.sigmoid,
        "tanh": torch.tanh,
        "leaky_relu": F.leaky_relu,
        # add more activations here as needed
    }
    if name not in activations:
        raise ValueError(f"Unknown activation function: {name}")
    return activations[name]

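# Example (illustrative): resolving an activation by name and applying it
# elementwise. Names not in the registry raise ValueError.
#
#     act = get_activation_function("leaky_relu")
#     out = act(torch.randn(4))
#     get_activation_function("gelu")  # ValueError: not registered above
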
class TwoLayerNN(BaseAdapter):
    """Two-layer transformation.

    Args:
        in_features (int): Input dimension.
        hidden_features (int): Hidden dimension.
        out_features (int): Output dimension.
        bias (bool): Whether to use bias. Defaults to False.
        activation_fn_str (str): Name of activation function. Defaults to "relu".
        add_residual (bool): Whether to add a learned, zero-initialized
            residual connection. Defaults to False.

    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        bias: bool = False,
        activation_fn_str: str = "relu",
        add_residual: bool = False,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.bias = bias
        self.activation_fn_str = activation_fn_str

        # `bias` controls whether the linear layers carry bias terms.
        self.linear1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.linear2 = nn.Linear(hidden_features, out_features, bias=bias)
        # Zero-seeding the weights (mirroring LinearLayer's identity init) is a
        # possible alternative, intentionally left disabled:
        # self.linear1.weight.data.copy_(torch.zeros(hidden_features, in_features))
        # self.linear2.weight.data.copy_(torch.zeros(out_features, hidden_features))
        # if bias:
        #     self.linear1.bias.data.copy_(torch.zeros(hidden_features))
        #     self.linear2.bias.data.copy_(torch.zeros(out_features))

        self._activation_function = get_activation_function(activation_fn_str)
        self._add_residual = add_residual
        # If add_residual is set, the residual weight (initialized to 0) scales
        # the MLP branch, so the adapter starts out as an identity mapping.
        self.residual_weight = nn.Parameter(torch.zeros(1))

    def forward(self, embed: Tensor) -> Tensor:
        """Forward pass (Wv).

        Args:
            embed (Tensor): Input tensor.

        """
        output1 = self.linear1(embed)
        output1 = self._activation_function(output1)
        output2 = self.linear2(output1)

        if self._add_residual:
            # Residual connection requires out_features == in_features.
            output2 = self.residual_weight * output2 + embed

        return output2

    def get_config_dict(self) -> Dict:
        """Get config dict."""
        return {
            "in_features": self.in_features,
            "hidden_features": self.hidden_features,
            "out_features": self.out_features,
            "bias": self.bias,
            "activation_fn_str": self.activation_fn_str,
            "add_residual": self._add_residual,
        }

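# Example (illustrative; the 384/1024 dimensions are arbitrary assumptions):
# with add_residual=True the adapter is an identity function at init, because
# residual_weight starts at 0 and zeroes out the MLP branch.
#
#     adapter = TwoLayerNN(384, 1024, 384, bias=True, add_residual=True)
#     embed = torch.randn(2, 384)
#     assert torch.allclose(adapter(embed), embed)  # holds only before training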