ray-llm / sdk.py · 191 lines · 5.3 KB

import os
import warnings
from typing import Any, Dict, Iterator, List, Optional, Union

import openai

from rayllm.common.models import ChatCompletion, Completion, Model
from rayllm.common.utils import (
    _get_langchain_model,
    _is_aviary_model,
    assert_has_backend,
)

__all__ = [
    "Model",
    "Completion",
    "ChatCompletion",
    "models",
    "metadata",
    "completions",
    "run",
    "get_aviary_backend",
    "stream",
]


class AviaryResource:
    """Stores information about the Aviary backend configuration."""

    def __init__(self, backend_url: str, token: str):
        assert "::param" not in backend_url, "backend_url not set correctly"
        assert "::param" not in token, "token not set correctly"

        self.backend_url = backend_url
        self.token = token
        self.bearer = f"Bearer {token}" if token else ""


class URLNotSetException(Exception):
    pass


def get_aviary_backend(verbose: Optional[bool] = None):
    """
    Establishes a connection to the Aviary backend, reading the
    connection information from environment variables.

    For a direct connection to the Aviary backend (e.g. running on the same
    cluster), no AVIARY_TOKEN is required. Otherwise, the AVIARY_URL and
    AVIARY_TOKEN environment variables are required.

    Args:
        verbose: Whether to print the connecting message.

    Returns:
        An instance of the AviaryResource class.
    """
    aviary_url = os.getenv("AVIARY_URL", os.getenv("OPENAI_API_BASE"))
    if not aviary_url:
        raise URLNotSetException("AVIARY_URL or OPENAI_API_BASE must be set")

    aviary_token = os.getenv("AVIARY_TOKEN", os.getenv("OPENAI_API_KEY")) or ""

    aviary_url += "/v1" if not aviary_url.endswith("/v1") else ""

    if verbose is None:
        verbose = os.environ.get("AVIARY_SILENT", "0") == "0"
    if verbose:
        print(f"Connecting to Aviary backend at: {aviary_url}")
    return AviaryResource(aviary_url, aviary_token)
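
# Usage sketch (hypothetical values): with the environment prepared beforehand,
#
#     os.environ["AVIARY_URL"] = "http://localhost:8000"
#     os.environ["AVIARY_TOKEN"] = "my-secret-token"  # optional for direct connections
#
# get_aviary_backend() then returns an AviaryResource whose backend_url is
# normalized to end in "/v1", i.e. "http://localhost:8000/v1" here.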


def get_openai_client() -> openai.Client:
    """Get an OpenAI Client connected to the ray-llm backend."""
    backend = get_aviary_backend()
    openai_client = openai.Client(base_url=backend.backend_url, api_key=backend.token)
    return openai_client
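
# The returned client is a plain OpenAI client, so any standard OpenAI call
# works against the backend. A minimal sketch (the model ID is hypothetical):
#
#     client = get_openai_client()
#     response = client.chat.completions.create(
#         model="meta-llama/Llama-2-7b-chat-hf",
#         messages=[{"role": "user", "content": "Hello!"}],
#     )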


def models() -> List[str]:
    """List available models"""
    models = get_openai_client().models.list()
    return [model.id for model in models.data]
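
# Example: the returned IDs depend on the running deployment, e.g.
#
#     print(models())  # ["meta-llama/Llama-2-7b-chat-hf", ...]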


def metadata(model_id: str) -> Dict[str, Dict[str, Any]]:
    """Get model metadata"""
    metadata = get_openai_client().models.retrieve(model_id).model_dump()
    return metadata
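
# Example (hypothetical model ID): the returned dict mirrors the OpenAI model
# object, e.g.
#
#     info = metadata("meta-llama/Llama-2-7b-chat-hf")
#     print(info["id"])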


def completions(
    model: str,
    prompt: str,
    use_prompt_format: bool = True,
    **kwargs,
) -> Dict[str, Union[str, float, int]]:
    """Get completions from Aviary models."""
    kwargs.setdefault("max_tokens", None)
    if _is_aviary_model(model):
        # With use_prompt_format the prompt is wrapped in a chat message so the
        # server applies the model's prompt template; otherwise the raw
        # completions endpoint is used.
        if use_prompt_format:
            result = get_openai_client().chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                stream=False,
                **kwargs,
            )
        else:
            result = get_openai_client().completions.create(
                model=model,
                prompt=prompt,
                stream=False,
                **kwargs,
            )
        return result.model_dump()
    llm = _get_langchain_model(model)
    return llm.predict(prompt)
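
# Example (hypothetical model ID): a blocking chat-style completion; extra
# keyword arguments such as temperature are passed through to the API call:
#
#     result = completions(
#         model="meta-llama/Llama-2-7b-chat-hf",
#         prompt="What are the three laws of robotics?",
#         temperature=0.2,
#     )
#     print(result["choices"][0]["message"]["content"])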


def query(
    model: str,
    prompt: str,
    use_prompt_format: bool = True,
    **kwargs,
) -> Dict[str, Union[str, float, int]]:
    warnings.warn(
        "'query' is deprecated, please use 'completions' instead",
        DeprecationWarning,
        stacklevel=2,
    )
    return completions(model, prompt, use_prompt_format, **kwargs)


def _iterator(gen):
    # Convert each streamed pydantic object to a plain dict.
    for x in gen:
        yield x.model_dump()


def stream(
    model: str,
    prompt: str,
    use_prompt_format: bool = True,
    **kwargs,
) -> Iterator[Dict[str, Union[str, float, int]]]:
    """Query Aviary and stream the response"""
    kwargs.setdefault("max_tokens", None)
    if _is_aviary_model(model):
        if use_prompt_format:
            result = get_openai_client().chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                stream=True,
                **kwargs,
            )
        else:
            result = get_openai_client().completions.create(
                model=model,
                prompt=prompt,
                stream=True,
                **kwargs,
            )
        return _iterator(result)
    else:
        # TODO implement streaming for langchain models
        raise RuntimeError("Streaming is currently only supported for aviary models")
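
# Example (hypothetical model ID): chunks follow the OpenAI streaming format,
# so the incremental text lives under choices[0]["delta"]:
#
#     for chunk in stream(
#         model="meta-llama/Llama-2-7b-chat-hf",
#         prompt="Tell me a short story.",
#     ):
#         delta = chunk["choices"][0]["delta"]
#         print(delta.get("content") or "", end="", flush=True)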


def run(*models: str, blocking: bool = True) -> None:
    """Run Aviary on the local Ray cluster

    Args:
        *models: Models to run.
        blocking: Whether to block the CLI until the application is ready.

    NOTE: This only works if you are running this command on the Ray or
    Anyscale cluster directly. It does not work from a general machine
    which only has the URL and token for a model.
    """
    assert_has_backend()
    from rayllm.backend.server.run import run

    run(*models, blocking=blocking)
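
# Example (hypothetical model argument): starting a deployment from the
# cluster itself, to be torn down later with shutdown() below:
#
#     run("models/meta-llama--Llama-2-7b-chat-hf.yaml", blocking=False)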


def shutdown() -> None:
    """Shutdown the Aviary backend server"""
    assert_has_backend()
    from ray import serve

    serve.shutdown()