llama-index

"""Embedding adapter model."""

import logging
from typing import Any, List, Optional, Type, cast

from llama_index.legacy.bridge.pydantic import PrivateAttr
from llama_index.legacy.callbacks import CallbackManager
from llama_index.legacy.constants import DEFAULT_EMBED_BATCH_SIZE
from llama_index.legacy.core.embeddings.base import BaseEmbedding
from llama_index.legacy.utils import infer_torch_device

logger = logging.getLogger(__name__)


class AdapterEmbeddingModel(BaseEmbedding):
    """Adapter for any embedding model.

    This is a wrapper around any embedding model that adds an adapter layer \
        on top of it.
    This is useful for finetuning an embedding model on a downstream task.
    The embedding model can be any model - it does not need to expose gradients.

    Args:
        base_embed_model (BaseEmbedding): Base embedding model.
        adapter_path (str): Path to adapter.
        adapter_cls (Optional[Type[Any]]): Adapter class. Defaults to None, in which \
            case a linear adapter is used.
        transform_query (bool): Whether to transform query embeddings. Defaults to True.
        device (Optional[str]): Device to use. Defaults to None.
        embed_batch_size (int): Batch size for embedding. Defaults to 10.
        callback_manager (Optional[CallbackManager]): Callback manager. \
            Defaults to None.

    """

    _base_embed_model: BaseEmbedding = PrivateAttr()
    _adapter: Any = PrivateAttr()
    _transform_query: bool = PrivateAttr()
    _device: Optional[str] = PrivateAttr()
    _target_device: Any = PrivateAttr()

    def __init__(
        self,
        base_embed_model: BaseEmbedding,
        adapter_path: str,
        adapter_cls: Optional[Type[Any]] = None,
        transform_query: bool = True,
        device: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """Init params."""
        import torch

        from llama_index.legacy.embeddings.adapter_utils import BaseAdapter, LinearLayer

        if device is None:
            device = infer_torch_device()
            logger.info(f"Use pytorch device: {device}")
        self._target_device = torch.device(device)

        self._base_embed_model = base_embed_model

        if adapter_cls is None:
            adapter_cls = LinearLayer
        else:
            adapter_cls = cast(Type[BaseAdapter], adapter_cls)

        adapter = adapter_cls.load(adapter_path)
        self._adapter = cast(BaseAdapter, adapter)
        self._adapter.to(self._target_device)

        self._transform_query = transform_query
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=f"Adapter for {base_embed_model.model_name}",
        )

    @classmethod
    def class_name(cls) -> str:
        return "AdapterEmbeddingModel"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        import torch

        query_embedding = self._base_embed_model._get_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        import torch

        query_embedding = await self._base_embed_model._aget_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        return self._base_embed_model._get_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return await self._base_embed_model._aget_text_embedding(text)


# Maintain for backwards compatibility
LinearAdapterEmbeddingModel = AdapterEmbeddingModel
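

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of loading a previously trained adapter on top of a base embedding
# model. It assumes OpenAIEmbedding is importable from
# llama_index.legacy.embeddings and that an adapter was already saved to
# "model_output/adapter" (an illustrative path), e.g. by a finetuning run.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from llama_index.legacy.embeddings import OpenAIEmbedding

    base_embed_model = OpenAIEmbedding()
    embed_model = AdapterEmbeddingModel(
        base_embed_model=base_embed_model,
        adapter_path="model_output/adapter",  # illustrative path to a saved adapter
        transform_query=True,  # apply the adapter to query embeddings only
    )

    # Query embeddings are passed through the adapter; text (document)
    # embeddings are returned by the base model unchanged.
    query_vec = embed_model.get_query_embedding("What is an adapter layer?")
    doc_vec = embed_model.get_text_embedding("Adapters finetune retrieval.")
    print(len(query_vec), len(doc_vec))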
