llama-index

Форк
0
1
"""Gradient Finetuning."""
2

3
import json
4
from typing import Any, Optional, overload
5

6
from llama_index.legacy.finetuning.types import BaseLLMFinetuneEngine
7
from llama_index.legacy.llms.gradient import GradientModelAdapterLLM
8

9

10
class GradientFinetuneEngine(BaseLLMFinetuneEngine):
    """Finetune engine backed by the Gradient AI platform.

    Exactly one of ``base_model_slug`` (creates a new model adapter on top of
    that base model) or ``model_adapter_id`` (resumes finetuning an existing
    adapter) must be supplied; the two ``@overload`` signatures below document
    the two valid call shapes.
    """

    @overload
    def __init__(
        self,
        *,
        access_token: Optional[str] = None,
        base_model_slug: str,
        data_path: str,
        host: Optional[str] = None,
        learning_rate: Optional[float] = None,
        name: str,
        rank: Optional[int] = None,
        workspace_id: Optional[str] = None,
    ) -> None:
        ...

    @overload
    def __init__(
        self,
        *,
        access_token: Optional[str] = None,
        data_path: str,
        host: Optional[str] = None,
        model_adapter_id: str,
        workspace_id: Optional[str] = None,
    ) -> None:
        ...

    def __init__(
        self,
        *,
        access_token: Optional[str] = None,
        base_model_slug: Optional[str] = None,
        data_path: str,
        host: Optional[str] = None,
        learning_rate: Optional[float] = None,
        model_adapter_id: Optional[str] = None,
        name: Optional[str] = None,
        rank: Optional[int] = None,
        workspace_id: Optional[str] = None,
        verbose: bool = True,
        max_steps: Optional[int] = None,
        batch_size: int = 1,
    ) -> None:
        """Create (or look up) the model adapter and store training settings.

        Args:
            access_token: Gradient API token; falls back to the SDK's own
                environment-variable resolution when ``None``.
            base_model_slug: base model to create a new adapter from
                (mutually exclusive with ``model_adapter_id``).
            data_path: path to a JSONL file of training samples; each line is
                a JSON object with an ``"inputs"`` key and an optional
                ``"multiplier"`` key.
            host: optional Gradient API host override.
            learning_rate: optional learning rate for a newly created adapter.
            model_adapter_id: existing adapter to continue finetuning
                (mutually exclusive with ``base_model_slug``).
            name: display name for a newly created adapter; required when
                ``base_model_slug`` is given.
            rank: optional LoRA rank for a newly created adapter.
            workspace_id: optional Gradient workspace.
            verbose: print per-step loss/token stats during :meth:`finetune`.
            max_steps: stop after this many input lines (``None`` = all).
            batch_size: number of samples sent per ``fine_tune`` call.

        Raises:
            ValueError: if not exactly one of ``base_model_slug`` /
                ``model_adapter_id`` is provided, or ``name`` is missing when
                creating a new adapter.
            ImportError: if the ``gradientai`` package is not installed.
        """
        self._access_token = access_token
        self._host = host
        self._workspace_id = workspace_id
        self._data_path = data_path
        self._max_steps = max_steps
        self._batch_size = batch_size
        self._verbose = verbose

        # XOR check: exactly one of the two identifiers must be provided.
        # (The original compared `isinstance(..., str)` on both, which let a
        # pair of non-None, non-str values slip through unreported.)
        if (base_model_slug is None) == (model_adapter_id is None):
            raise ValueError(
                "expected exactly one of base_model_slug or model_adapter_id"
            )

        # Keep the try body minimal: only the import can raise ImportError.
        try:
            from gradientai import Gradient
        except ImportError as e:
            raise ImportError(
                "Could not import Gradient Python package. "
                "Please install it with `pip install gradientai`."
            ) from e

        self._gradient = Gradient(
            access_token=access_token, host=host, workspace_id=workspace_id
        )
        if base_model_slug is not None:
            if name is None:
                raise ValueError("name must be provided with a base_model_slug")
            self._model_adapter = self._gradient.get_base_model(
                base_model_slug=base_model_slug
            ).create_model_adapter(name=name, rank=rank, learning_rate=learning_rate)
        else:
            self._model_adapter = self._gradient.get_model_adapter(
                model_adapter_id=model_adapter_id
            )

    def close(self) -> None:
        """Release the underlying Gradient client's resources."""
        self._gradient.close()

    def _log_step(self, step: int, response: Any) -> None:
        """Print per-step training metrics when verbose mode is on."""
        if self._verbose:
            print(
                f"fine-tuning step {step}: loss={response.sum_loss}, "
                f"trainable tokens={response.number_of_trainable_tokens}"
            )

    def finetune(self) -> None:
        """Stream samples from ``data_path`` to the adapter in batches.

        Reads the JSONL file line by line (up to ``max_steps`` lines),
        accumulates ``batch_size`` samples, and sends each full batch to
        ``model_adapter.fine_tune``. A trailing partial batch is flushed at
        the end — and, unlike the original implementation, its response is
        also logged in verbose mode.

        Raises:
            ValueError: if a line parses to something other than a JSON object.
        """
        from gradientai import Sample

        batch = []
        step = 0
        with open(self._data_path) as f:
            for i, line in enumerate(f):
                if self._max_steps is not None and i >= self._max_steps:
                    break
                step = i + 1
                record = json.loads(line)
                if not isinstance(record, dict):
                    raise ValueError(
                        f"each line should be a json object. line {i + 1} does not parse correctly"
                    )
                batch.append(
                    Sample(
                        inputs=record["inputs"],
                        multiplier=record.get("multiplier", None),
                    )
                )
                if len(batch) == self._batch_size:
                    response = self._model_adapter.fine_tune(samples=batch)
                    batch = []
                    self._log_step(step, response)

        # Flush any remaining partial batch.
        if batch:
            response = self._model_adapter.fine_tune(samples=batch)
            self._log_step(step, response)

    @property
    def model_adapter_id(self) -> str:
        """ID of the model adapter being finetuned."""
        return self._model_adapter.id

    @property
    def model_adapter(self) -> Any:
        """The underlying Gradient model-adapter object."""
        return self._model_adapter

    def get_finetuned_model(self, **model_kwargs: Any) -> GradientModelAdapterLLM:
        """Return an LLM wrapper bound to the finetuned adapter.

        Extra ``model_kwargs`` are forwarded to ``GradientModelAdapterLLM``.
        """
        return GradientModelAdapterLLM(
            access_token=self._access_token,
            host=self._host,
            model_adapter_id=self._model_adapter.id,
            workspace_id=self._workspace_id,
            **model_kwargs,
        )
146

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.