llama-index

"""Sentence Transformer Finetuning Engine."""

import logging
from typing import Any, List, Optional, Tuple, Type, cast

from llama_index.legacy.embeddings.adapter import AdapterEmbeddingModel
from llama_index.legacy.embeddings.base import BaseEmbedding
from llama_index.legacy.finetuning.embeddings.common import EmbeddingQAFinetuneDataset
from llama_index.legacy.finetuning.types import BaseEmbeddingFinetuneEngine
from llama_index.legacy.utils import infer_torch_device

logger = logging.getLogger(__name__)


class EmbeddingAdapterFinetuneEngine(BaseEmbeddingFinetuneEngine):
    """Embedding adapter finetune engine.

    Args:
        dataset (EmbeddingQAFinetuneDataset): Dataset to finetune on.
        embed_model (BaseEmbedding): Embedding model to finetune.
        batch_size (int): Batch size. Defaults to 10.
        epochs (int): Number of epochs. Defaults to 1.
        adapter_model (Optional[BaseAdapter]): Adapter model. Defaults to None, in which
            case a linear adapter is used.
        dim (Optional[int]): Dimension of embedding. Defaults to None, in which
            case the dimension is inferred from the embedding model.
        device (Optional[str]): Device to use. Defaults to None, in which case
            the device is inferred via infer_torch_device().
        model_output_path (str): Path to save model output. Defaults to "model_output".
        model_checkpoint_path (Optional[str]): Path to save model checkpoints.
            Defaults to None (don't save checkpoints).
        checkpoint_save_steps (int): Save a checkpoint every N training steps.
            Defaults to 100.
        verbose (bool): Whether to show a progress bar. Defaults to False.
        bias (bool): Whether to use a bias term in the default linear adapter.
            Defaults to False.

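    Example:
        A minimal usage sketch (illustrative only; the dataset path, the
        ``resolve_embed_model`` import, and the model string are assumptions,
        not part of this module)::

            from llama_index.legacy.embeddings.utils import resolve_embed_model

            dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
            base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")

            engine = EmbeddingAdapterFinetuneEngine(
                dataset,
                base_embed_model,
                model_output_path="model_output",
                epochs=4,
                verbose=True,
            )
            engine.finetune()
            embed_model = engine.get_finetuned_model()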
    """

    def __init__(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        embed_model: BaseEmbedding,
        batch_size: int = 10,
        epochs: int = 1,
        adapter_model: Optional[Any] = None,
        dim: Optional[int] = None,
        device: Optional[str] = None,
        model_output_path: str = "model_output",
        model_checkpoint_path: Optional[str] = None,
        checkpoint_save_steps: int = 100,
        verbose: bool = False,
        bias: bool = False,
        **train_kwargs: Any,
    ) -> None:
        """Init params."""
        import torch

        from llama_index.legacy.embeddings.adapter_utils import BaseAdapter, LinearLayer

        self.dataset = dataset
        self.embed_model = embed_model

        # HACK: get dimension by passing text through it
        if dim is None:
            test_embedding = self.embed_model.get_text_embedding("hello world")
            self.dim = len(test_embedding)
        else:
            self.dim = dim

        # load in data, run embedding model, define data loader

        self.batch_size = batch_size
        self.loader = self._get_data_loader(dataset)

        if device is None:
            device = infer_torch_device()
            logger.info(f"Use pytorch device: {device}")
        self._target_device = torch.device(device)

        if adapter_model is not None:
            self.model = cast(BaseAdapter, adapter_model)
        else:
            self.model = LinearLayer(self.dim, self.dim, bias=bias)

        self._model_output_path = model_output_path
        self._model_checkpoint_path = model_checkpoint_path
        self._checkpoint_save_steps = checkpoint_save_steps
        self._epochs = epochs
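        # warmup covers the first 10% of the total training steps
        # (batches per epoch * epochs)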
        self._warmup_steps = int(len(self.loader) * epochs * 0.1)
        self._train_kwargs = train_kwargs

        self._verbose = verbose

    @classmethod
    def from_model_path(
        cls,
        dataset: EmbeddingQAFinetuneDataset,
        embed_model: BaseEmbedding,
        model_path: str,
        model_cls: Optional[Type[Any]] = None,
        **kwargs: Any,
    ) -> "EmbeddingAdapterFinetuneEngine":
        """Load from model path.

        Args:
            dataset (EmbeddingQAFinetuneDataset): Dataset to finetune on.
            embed_model (BaseEmbedding): Embedding model to finetune.
            model_path (str): Path to model.
            model_cls (Optional[Type[Any]]): Adapter model class. Defaults to None.
            **kwargs (Any): Additional kwargs (see __init__)

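        Example:
            Illustrative sketch for resuming finetuning from a previously saved
            adapter (the "model_output" path and the ``dataset`` /
            ``base_embed_model`` variables are assumptions)::

                engine = EmbeddingAdapterFinetuneEngine.from_model_path(
                    dataset,
                    base_embed_model,
                    model_path="model_output",
                )
                engine.finetune()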
        """
        from llama_index.legacy.embeddings.adapter_utils import LinearLayer

        model_cls = model_cls or LinearLayer
        model = model_cls.load(model_path)
        return cls(dataset, embed_model, adapter_model=model, **kwargs)

    def smart_batching_collate(self, batch: List) -> Tuple[Any, Any]:
        """Collate (query, text) pairs into stacked embedding tensors."""
        import torch
        from torch import Tensor

        query_embeddings: List[Tensor] = []
        text_embeddings: List[Tensor] = []

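        # embed each (query, text) pair with the base embedding model; the
        # adapter trained downstream operates on these precomputed embeddings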
        for query, text in batch:
            query_embedding = self.embed_model.get_query_embedding(query)
            text_embedding = self.embed_model.get_text_embedding(text)

            query_embeddings.append(torch.tensor(query_embedding))
            text_embeddings.append(torch.tensor(text_embedding))

        query_embeddings_t = torch.stack(query_embeddings)
        text_embeddings_t = torch.stack(text_embeddings)

        return query_embeddings_t, text_embeddings_t

    def _get_data_loader(self, dataset: EmbeddingQAFinetuneDataset) -> Any:
        """Build a DataLoader of (query, text) pairs from the dataset."""
        from torch.utils.data import DataLoader

        examples: Any = []

        for query_id, query in dataset.queries.items():
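            # use the first relevant doc listed for each query as its positive text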
            node_id = dataset.relevant_docs[query_id][0]
            text = dataset.corpus[node_id]

            examples.append((query, text))

        data_loader = DataLoader(examples, batch_size=self.batch_size)
        data_loader.collate_fn = self.smart_batching_collate

        return data_loader

    def finetune(self, **train_kwargs: Any) -> None:
        """Finetune the adapter model."""
        from llama_index.legacy.finetuning.embeddings.adapter_utils import train_model

        # call model training
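        # note: kwargs passed to finetune() itself are currently unused; training
        # options come from the **train_kwargs captured in __init__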
        train_model(
            self.model,
            self.loader,
            self._target_device,
            epochs=self._epochs,
            output_path=self._model_output_path,
            warmup_steps=self._warmup_steps,
            verbose=self._verbose,
            checkpoint_path=self._model_checkpoint_path,
            checkpoint_save_steps=self._checkpoint_save_steps,
            **self._train_kwargs,
        )

    def get_finetuned_model(self, **model_kwargs: Any) -> BaseEmbedding:
        """Get finetuned model."""
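        # wrap the base embedding model with the adapter weights saved under
        # model_output_path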
        return AdapterEmbeddingModel(
            self.embed_model, self._model_output_path, **model_kwargs
        )
