llama-index
78 строк · 2.7 Кб
1"""Cohere Reranker Finetuning Engine."""
2
3import importlib.util
4import os
5from typing import Optional
6
7from llama_index.legacy.finetuning.types import BaseCohereRerankerFinetuningEngine
8from llama_index.legacy.indices.postprocessor import CohereRerank
9
10
class CohereRerankerFinetuneEngine(BaseCohereRerankerFinetuningEngine):
    """Finetuning engine for a custom Cohere rerank model.

    Wraps a ``cohere.Client`` to submit a custom-model job from JSONL
    training data (``finetune``) and to expose the resulting model as a
    ``CohereRerank`` postprocessor (``get_finetuned_model``).
    """

    def __init__(
        self,
        train_file_name: str = "train.jsonl",
        val_file_name: Optional[str] = None,
        model_name: str = "exp_finetune",
        model_type: str = "RERANK",
        base_model: str = "english",
        api_key: Optional[str] = None,
    ) -> None:
        """Create the Cohere client and store the finetuning settings.

        Args:
            train_file_name: Path of the JSONL training file handed to
                ``JsonlDataset`` when ``finetune`` is called.
            val_file_name: Optional path of a JSONL evaluation file. When
                falsy, only the training file is uploaded.
            model_name: Name passed to ``create_custom_model``.
            model_type: Model type passed to ``create_custom_model``.
            base_model: Base model passed to ``create_custom_model``.
            api_key: Cohere API key. When falsy, the ``COHERE_API_KEY``
                environment variable is used instead.

        Raises:
            ImportError: If the ``cohere`` package is not installed.
            ValueError: If no API key is passed and ``COHERE_API_KEY`` is
                not set in the environment.
        """
        # find_spec returns None when the 'cohere' module is not available,
        # which lets us fail with an actionable installation hint.
        if importlib.util.find_spec("cohere") is None:
            raise ImportError(
                "Cannot import cohere. Please install the package using `pip install cohere`."
            )
        import cohere

        try:
            self.api_key = api_key or os.environ["COHERE_API_KEY"]
        except KeyError:
            # os.environ is a mapping: a missing variable raises KeyError
            # (not IndexError), so this is the exception to translate.
            raise ValueError(
                "Must pass in cohere api key or "
                "specify via COHERE_API_KEY environment variable "
            ) from None

        self._model = cohere.Client(self.api_key, client_name="llama_index")
        self._train_file_name = train_file_name
        self._val_file_name = val_file_name
        self._model_name = model_name
        self._model_type = model_type
        self._base_model = base_model
        # Set by finetune(); None means no custom model has been created yet.
        self._finetune_model = None

    def finetune(self) -> None:
        """Build the dataset and create the custom model.

        Constructs a ``JsonlDataset`` from the training file (plus the
        evaluation file when one was configured), calls
        ``create_custom_model`` on the Cohere client, and stores the
        returned object for later use by ``get_finetuned_model``.
        """
        from cohere.custom_model_dataset import JsonlDataset

        if self._val_file_name:
            # Uploading both train file and eval file
            dataset = JsonlDataset(
                train_file=self._train_file_name, eval_file=self._val_file_name
            )
        else:
            # Single Train File Upload:
            dataset = JsonlDataset(train_file=self._train_file_name)

        self._finetune_model = self._model.create_custom_model(
            name=self._model_name,
            dataset=dataset,
            model_type=self._model_type,
            base_model=self._base_model,
        )

    def get_finetuned_model(self, top_n: int = 5) -> CohereRerank:
        """Return a ``CohereRerank`` postprocessor backed by the finetuned model.

        Args:
            top_n: Forwarded to ``CohereRerank`` as its ``top_n`` argument.

        Returns:
            A ``CohereRerank`` configured with the id of the model created
            by ``finetune`` and this engine's API key.

        Raises:
            RuntimeError: If ``finetune`` has not been run yet.
        """
        if self._finetune_model is None:
            raise RuntimeError(
                "Finetuned model is not set yet. Please run the finetune method first."
            )
        return CohereRerank(
            model=self._finetune_model.id, top_n=top_n, api_key=self.api_key
        )
79