llama-index

Форк
0
1
"""OpenAI Finetuning."""
2

3
import logging
4
import os
5
import time
6
from typing import Any, Optional
7

8
import openai
9
from openai import OpenAI as SyncOpenAI
10
from openai.types.fine_tuning import FineTuningJob
11

12
from llama_index.legacy.callbacks import OpenAIFineTuningHandler
13
from llama_index.legacy.finetuning.openai.validate_json import validate_json
14
from llama_index.legacy.finetuning.types import BaseLLMFinetuneEngine
15
from llama_index.legacy.llms import OpenAI
16
from llama_index.legacy.llms.llm import LLM
17

18
logger = logging.getLogger(__name__)
19

20

21
class OpenAIFinetuneEngine(BaseLLMFinetuneEngine):
    """OpenAI Finetuning Engine.

    Uploads a training data file to OpenAI, launches a fine-tuning job on
    ``base_model``, and wraps the resulting fine-tuned model as an ``LLM``.
    """

    def __init__(
        self,
        base_model: str,
        data_path: str,
        verbose: bool = False,
        start_job_id: Optional[str] = None,
        validate_json: bool = True,
    ) -> None:
        """Init params.

        Args:
            base_model: Name of the OpenAI model to fine-tune.
            data_path: Path to the training data file that ``finetune()``
                uploads.
            verbose: If True, progress messages are also printed to stdout
                (they are always sent to the module logger).
            start_job_id: Optional id of an existing fine-tuning job. When
                given, the job is retrieved immediately so that
                ``get_current_job`` / ``get_finetuned_model`` work without
                calling ``finetune()``.
            validate_json: If True, ``finetune()`` runs the ``validate_json``
                checker on ``data_path`` before uploading. (This parameter
                shadows the imported function only inside ``__init__``.)
        """
        self.base_model = base_model
        self.data_path = data_path
        self._verbose = verbose
        self._validate_json = validate_json
        # Most recently launched (or attached) job; None until one exists.
        self._start_job: Optional[FineTuningJob] = None
        # The API key is read from the OPENAI_API_KEY environment variable.
        self._client = SyncOpenAI(api_key=os.getenv("OPENAI_API_KEY", None))
        if start_job_id is not None:
            self._start_job = self._client.fine_tuning.jobs.retrieve(start_job_id)

    @classmethod
    def from_finetuning_handler(
        cls,
        finetuning_handler: OpenAIFineTuningHandler,
        base_model: str,
        data_path: str,
        **kwargs: Any,
    ) -> "OpenAIFinetuneEngine":
        """Initialize from finetuning handler.

        Used to finetune an OpenAI model into another
        OpenAI model (e.g. gpt-3.5-turbo on top of GPT-4).

        The handler's recorded events are first written to ``data_path``;
        an engine is then constructed for that file. Extra ``kwargs`` are
        forwarded to ``__init__``.
        """
        finetuning_handler.save_finetuning_events(data_path)
        return cls(base_model=base_model, data_path=data_path, **kwargs)

    def finetune(self) -> None:
        """Finetune model.

        Optionally validates ``data_path``, uploads it with purpose
        ``"fine-tune"``, then creates a fine-tuning job on ``base_model``.
        Job creation is retried every 60 seconds for as long as the API
        answers with ``openai.BadRequestError``. The created job is stored
        so that ``get_current_job`` can look it up later.
        """
        if self._validate_json:
            validate_json(self.data_path)

        # TODO: figure out how to specify file name in the new API
        # file_name = os.path.basename(self.data_path)

        # upload file
        with open(self.data_path, "rb") as f:
            output = self._client.files.create(file=f, purpose="fine-tune")
        logger.info("File uploaded...")
        if self._verbose:
            print("File uploaded...")

        # launch training
        while True:
            try:
                job = self._client.fine_tuning.jobs.create(
                    training_file=output.id, model=self.base_model
                )
                break
            except openai.BadRequestError as exc:
                # Presumably raised while the uploaded file is still being
                # processed. NOTE(review): any other BadRequestError (e.g. an
                # invalid base model) is retried forever too, so the error
                # text is logged to make such a case visible.
                wait_msg = "Waiting for file to be ready..."
                logger.info("%s (%s)", wait_msg, exc)
                if self._verbose:
                    print(wait_msg)
                time.sleep(60)
        self._start_job = job

        # Report the *job* id (the original reported the uploaded file's id).
        info_str = (
            f"Training job {job.id} launched. "
            "You will be emailed when it's complete."
        )
        logger.info(info_str)
        if self._verbose:
            print(info_str)

    def get_current_job(self) -> FineTuningJob:
        """Get current job.

        Returns:
            The current state of the tracked job, re-fetched from the API
            on every call.

        Raises:
            ValueError: If no job has been launched or attached yet.
        """
        if not self._start_job:
            raise ValueError("Must call finetune() first")

        job_id = self._start_job.id
        return self._client.fine_tuning.jobs.retrieve(job_id)

    def get_finetuned_model(self, **model_kwargs: Any) -> LLM:
        """Gets finetuned model.

        Args:
            **model_kwargs: Extra keyword arguments forwarded to the ``OpenAI``
                LLM constructor.

        Returns:
            An ``OpenAI`` LLM bound to the fine-tuned model id.

        Raises:
            ValueError: If the job has not succeeded, or it succeeded but
                reports no fine-tuned model id.
        """
        current_job = self.get_current_job()

        job_id = current_job.id
        status = current_job.status
        model_id = current_job.fine_tuned_model

        # Check status first: a failed/cancelled job also has no model id, and
        # its status is the more informative error for the caller.
        if status != "succeeded":
            raise ValueError(f"Job {job_id} has status {status}, cannot get model")
        if model_id is None:
            raise ValueError(
                f"Job {job_id} does not have a finetuned model id ready yet."
            )

        return OpenAI(model=model_id, **model_kwargs)
119

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.