llama-index

Форк
0
78 строк · 2.7 Кб
1
"""Cohere Reranker Finetuning Engine."""
2

3
import importlib.util
4
import os
5
from typing import Optional
6

7
from llama_index.legacy.finetuning.types import BaseCohereRerankerFinetuningEngine
8
from llama_index.legacy.indices.postprocessor import CohereRerank
9

10

11
class CohereRerankerFinetuneEngine(BaseCohereRerankerFinetuningEngine):
    """Cohere Reranker Finetune Engine.

    Builds a Cohere ``JsonlDataset`` from a training file (and an optional
    evaluation file), starts a custom-model job via the Cohere client, and
    exposes the resulting model as a ``CohereRerank`` postprocessor.
    """

    def __init__(
        self,
        train_file_name: str = "train.jsonl",
        val_file_name: Optional[str] = None,
        model_name: str = "exp_finetune",
        model_type: str = "RERANK",
        base_model: str = "english",
        api_key: Optional[str] = None,
    ) -> None:
        """Init params.

        Args:
            train_file_name: Path of the JSONL training file.
            val_file_name: Optional path of a JSONL evaluation file.
            model_name: Name passed to Cohere for the custom model.
            model_type: Model type passed to Cohere (``"RERANK"`` by default).
            base_model: Base model passed to Cohere (``"english"`` by default).
            api_key: Cohere API key. If falsy, ``COHERE_API_KEY`` from the
                environment is used instead.

        Raises:
            ImportError: If the ``cohere`` package cannot be imported.
            ValueError: If no API key is passed and ``COHERE_API_KEY`` is
                unset or empty.
        """
        try:
            import cohere
        except ImportError as err:
            raise ImportError(
                "Cannot import cohere. Please install the package using `pip install cohere`."
            ) from err

        # Indexing os.environ raises KeyError (not IndexError) when the
        # variable is missing. Using .get() routes both "missing" and "empty"
        # to the explicit, user-friendly ValueError below.
        self.api_key = api_key or os.environ.get("COHERE_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Must pass in cohere api key or "
                "specify via COHERE_API_KEY environment variable "
            )

        self._model = cohere.Client(self.api_key, client_name="llama_index")
        self._train_file_name = train_file_name
        self._val_file_name = val_file_name
        self._model_name = model_name
        self._model_type = model_type
        self._base_model = base_model
        # Holds the object returned by create_custom_model(); None until
        # finetune() has run.
        self._finetune_model = None

    def finetune(self) -> None:
        """Finetune model.

        Builds a ``JsonlDataset`` from the configured training file, plus the
        evaluation file when one was given. Submits it through the client's
        ``create_custom_model`` and stores the returned object on the engine.
        """
        from cohere.custom_model_dataset import JsonlDataset

        if self._val_file_name:
            # Upload both the train file and the eval file.
            dataset = JsonlDataset(
                train_file=self._train_file_name, eval_file=self._val_file_name
            )
        else:
            # Single train-file upload.
            dataset = JsonlDataset(train_file=self._train_file_name)

        self._finetune_model = self._model.create_custom_model(
            name=self._model_name,
            dataset=dataset,
            model_type=self._model_type,
            base_model=self._base_model,
        )

    def get_finetuned_model(self, top_n: int = 5) -> CohereRerank:
        """Return a ``CohereRerank`` postprocessor backed by the finetuned model.

        Args:
            top_n: Number of top-ranked nodes the reranker should keep.

        Returns:
            A ``CohereRerank`` configured with the finetuned model's ``id``
            and this engine's API key.

        Raises:
            RuntimeError: If ``finetune()`` has not been called yet.
        """
        if self._finetune_model is None:
            raise RuntimeError(
                "Finetuned model is not set yet. Please run the finetune method first."
            )
        return CohereRerank(
            model=self._finetune_model.id, top_n=top_n, api_key=self.api_key
        )
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.