llama-index
174 lines · 6.3 KB
1"""Sentence Transformer Finetuning Engine."""
2
3import logging4from typing import Any, List, Optional, Tuple, Type, cast5
6from llama_index.legacy.embeddings.adapter import AdapterEmbeddingModel7from llama_index.legacy.embeddings.base import BaseEmbedding8from llama_index.legacy.finetuning.embeddings.common import EmbeddingQAFinetuneDataset9from llama_index.legacy.finetuning.types import BaseEmbeddingFinetuneEngine10from llama_index.legacy.utils import infer_torch_device11
12logger = logging.getLogger(__name__)13
14
class EmbeddingAdapterFinetuneEngine(BaseEmbeddingFinetuneEngine):
    """Embedding adapter finetune engine.

    Trains a small adapter (by default a single linear layer) on top of a
    frozen embedding model, using (query, text) pairs drawn from the dataset.

    Args:
        dataset (EmbeddingQAFinetuneDataset): Dataset to finetune on.
        embed_model (BaseEmbedding): Embedding model to finetune.
        batch_size (int): Batch size. Defaults to 10.
        epochs (int): Number of epochs. Defaults to 1.
        adapter_model (Optional[BaseAdapter]): Adapter model. Defaults to None, in which
            case a linear adapter is used.
        dim (Optional[int]): Dimension of embedding. Defaults to None, in which
            case it is inferred by embedding a probe string.
        device (Optional[str]): Device to use. Defaults to None (auto-detect).
        model_output_path (str): Path to save model output. Defaults to "model_output".
        model_checkpoint_path (Optional[str]): Path to save model checkpoints.
            Defaults to None (don't save checkpoints).
        checkpoint_save_steps (int): Save a checkpoint every N steps.
            Defaults to 100 (only used when model_checkpoint_path is set).
        verbose (bool): Whether to show progress bar. Defaults to False.
        bias (bool): Whether to use bias. Defaults to False.

    """

    def __init__(
        self,
        dataset: EmbeddingQAFinetuneDataset,
        embed_model: BaseEmbedding,
        batch_size: int = 10,
        epochs: int = 1,
        adapter_model: Optional[Any] = None,
        dim: Optional[int] = None,
        device: Optional[str] = None,
        model_output_path: str = "model_output",
        model_checkpoint_path: Optional[str] = None,
        checkpoint_save_steps: int = 100,
        verbose: bool = False,
        bias: bool = False,
        **train_kwargs: Any,
    ) -> None:
        """Init params."""
        # torch and the adapter utilities are imported lazily so that merely
        # importing this module does not require torch to be installed
        import torch

        from llama_index.legacy.embeddings.adapter_utils import (
            BaseAdapter,
            LinearLayer,
        )

        self.dataset = dataset
        self.embed_model = embed_model

        # HACK: get dimension by passing text through it
        if dim is None:
            test_embedding = self.embed_model.get_text_embedding("hello world")
            self.dim = len(test_embedding)
        else:
            self.dim = dim

        # load in data, run embedding model, define data loader
        self.batch_size = batch_size
        self.loader = self._get_data_loader(dataset)

        if device is None:
            device = infer_torch_device()
            logger.info(f"Use pytorch device: {device}")
        self._target_device = torch.device(device)

        if adapter_model is not None:
            self.model = cast(BaseAdapter, adapter_model)
        else:
            # default adapter: square linear layer (dim -> dim)
            self.model = LinearLayer(self.dim, self.dim, bias=bias)

        self._model_output_path = model_output_path
        self._model_checkpoint_path = model_checkpoint_path
        self._checkpoint_save_steps = checkpoint_save_steps
        self._epochs = epochs
        # warm the LR up over the first 10% of total training steps
        self._warmup_steps = int(len(self.loader) * epochs * 0.1)
        self._train_kwargs = train_kwargs

        self._verbose = verbose

    @classmethod
    def from_model_path(
        cls,
        dataset: EmbeddingQAFinetuneDataset,
        embed_model: BaseEmbedding,
        model_path: str,
        model_cls: Optional[Type[Any]] = None,
        **kwargs: Any,
    ) -> "EmbeddingAdapterFinetuneEngine":
        """Load from model path.

        Args:
            dataset (EmbeddingQAFinetuneDataset): Dataset to finetune on.
            embed_model (BaseEmbedding): Embedding model to finetune.
            model_path (str): Path to model.
            model_cls (Optional[Type[Any]]): Adapter model class. Defaults to None
                (uses a linear adapter).
            **kwargs (Any): Additional kwargs (see __init__)

        """
        from llama_index.legacy.embeddings.adapter_utils import LinearLayer

        model_cls = model_cls or LinearLayer
        model = model_cls.load(model_path)
        return cls(dataset, embed_model, adapter_model=model, **kwargs)

    def smart_batching_collate(self, batch: List) -> Tuple[Any, Any]:
        """Smart batching collate.

        Embeds each (query, text) pair in the batch with the wrapped embedding
        model and stacks the results into two tensors of shape
        (batch_size, dim): one for queries, one for texts.
        """
        import torch
        from torch import Tensor

        query_embeddings: List[Tensor] = []
        text_embeddings: List[Tensor] = []

        for query, text in batch:
            query_embedding = self.embed_model.get_query_embedding(query)
            text_embedding = self.embed_model.get_text_embedding(text)

            query_embeddings.append(torch.tensor(query_embedding))
            text_embeddings.append(torch.tensor(text_embedding))

        query_embeddings_t = torch.stack(query_embeddings)
        text_embeddings_t = torch.stack(text_embeddings)

        return query_embeddings_t, text_embeddings_t

    def _get_data_loader(self, dataset: EmbeddingQAFinetuneDataset) -> Any:
        """Get data loader.

        Builds (query, text) example pairs — pairing each query with its first
        relevant document only — and wraps them in a torch DataLoader that
        embeds batches on the fly via ``smart_batching_collate``.
        """
        from torch.utils.data import DataLoader

        examples: Any = []

        for query_id, query in dataset.queries.items():
            # NOTE: only the first relevant doc per query is used
            node_id = dataset.relevant_docs[query_id][0]
            text = dataset.corpus[node_id]

            examples.append((query, text))

        data_loader = DataLoader(examples, batch_size=self.batch_size)
        data_loader.collate_fn = self.smart_batching_collate

        return data_loader

    def finetune(self, **train_kwargs: Any) -> None:
        """Finetune.

        Runs the adapter training loop with the settings captured in
        ``__init__``; the trained adapter is written to the model output path.
        """
        from llama_index.legacy.finetuning.embeddings.adapter_utils import train_model

        # call model training
        train_model(
            self.model,
            self.loader,
            self._target_device,
            epochs=self._epochs,
            output_path=self._model_output_path,
            warmup_steps=self._warmup_steps,
            verbose=self._verbose,
            checkpoint_path=self._model_checkpoint_path,
            checkpoint_save_steps=self._checkpoint_save_steps,
            **self._train_kwargs,
        )

    def get_finetuned_model(self, **model_kwargs: Any) -> BaseEmbedding:
        """Get finetuned model.

        Returns the base embedding model wrapped with the trained adapter
        loaded from the model output path.
        """
        return AdapterEmbeddingModel(
            self.embed_model, self._model_output_path, **model_kwargs
        )