promptfoo

Форк
0
/
localai.ts 
175 строк · 4.5 Кб
1
import logger from '../logger';
2
import { fetchWithCache } from '../cache';
3
import { REQUEST_TIMEOUT_MS, parseChatPrompt } from './shared';
4

5
import type {
6
  ApiProvider,
7
  EnvOverrides,
8
  ProviderEmbeddingResponse,
9
  ProviderResponse,
10
} from '../types.js';
11

12
// Per-provider configuration options shared by all LocalAI provider classes.
interface LocalAiCompletionOptions {
  // Base URL of the LocalAI server's OpenAI-compatible API,
  // e.g. `http://localhost:8080/v1`. Falls back to LOCALAI_BASE_URL env var.
  apiBaseUrl?: string;
  // Sampling temperature forwarded in completion / chat-completion requests.
  temperature?: number;
}
16

17
class LocalAiGenericProvider implements ApiProvider {
18
  modelName: string;
19
  apiBaseUrl: string;
20
  config: LocalAiCompletionOptions;
21

22
  constructor(
23
    modelName: string,
24
    options: { config?: LocalAiCompletionOptions; id?: string; env?: EnvOverrides } = {},
25
  ) {
26
    const { id, config, env } = options;
27
    this.modelName = modelName;
28
    this.apiBaseUrl =
29
      config?.apiBaseUrl ||
30
      env?.LOCALAI_BASE_URL ||
31
      process.env.LOCALAI_BASE_URL ||
32
      'http://localhost:8080/v1';
33
    this.config = config || {};
34
    this.id = id ? () => id : this.id;
35
  }
36

37
  id(): string {
38
    return `localai:${this.modelName}`;
39
  }
40

41
  toString(): string {
42
    return `[LocalAI Provider ${this.modelName}]`;
43
  }
44

45
  // @ts-ignore: Prompt is not used in this implementation
46
  async callApi(prompt: string): Promise<ProviderResponse> {
47
    throw new Error('Not implemented');
48
  }
49
}
50

51
export class LocalAiChatProvider extends LocalAiGenericProvider {
52
  async callApi(prompt: string): Promise<ProviderResponse> {
53
    const messages = parseChatPrompt(prompt, [{ role: 'user', content: prompt }]);
54
    const body = {
55
      model: this.modelName,
56
      messages: messages,
57
      temperature: this.config.temperature || process.env.LOCALAI_TEMPERATURE || 0.7,
58
    };
59
    logger.debug(`Calling LocalAI API: ${JSON.stringify(body)}`);
60

61
    let data,
62
      cached = false;
63
    try {
64
      ({ data, cached } = (await fetchWithCache(
65
        `${this.apiBaseUrl}/chat/completions`,
66
        {
67
          method: 'POST',
68
          headers: {
69
            'Content-Type': 'application/json',
70
          },
71
          body: JSON.stringify(body),
72
        },
73
        REQUEST_TIMEOUT_MS,
74
      )) as unknown as any);
75
    } catch (err) {
76
      return {
77
        error: `API call error: ${String(err)}`,
78
      };
79
    }
80
    logger.debug(`\tLocalAI API chat completions response: ${JSON.stringify(data)}`);
81
    try {
82
      return {
83
        output: data.choices[0].message.content,
84
      };
85
    } catch (err) {
86
      return {
87
        error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
88
      };
89
    }
90
  }
91
}
92

93
export class LocalAiEmbeddingProvider extends LocalAiGenericProvider {
94
  async callEmbeddingApi(text: string): Promise<ProviderEmbeddingResponse> {
95
    const body = {
96
      input: text,
97
      model: this.modelName,
98
    };
99
    let data,
100
      cached = false;
101
    try {
102
      ({ data, cached } = (await fetchWithCache(
103
        `${this.apiBaseUrl}/embeddings`,
104
        {
105
          method: 'POST',
106
          headers: {
107
            'Content-Type': 'application/json',
108
          },
109
          body: JSON.stringify(body),
110
        },
111
        REQUEST_TIMEOUT_MS,
112
      )) as unknown as any);
113
    } catch (err) {
114
      return {
115
        error: `API call error: ${String(err)}`,
116
      };
117
    }
118
    logger.debug(`\tLocalAI embeddings API response: ${JSON.stringify(data)}`);
119

120
    try {
121
      const embedding = data?.data?.[0]?.embedding;
122
      if (!embedding) {
123
        throw new Error('No embedding found in LocalAI embeddings API response');
124
      }
125
      return {
126
        embedding,
127
      };
128
    } catch (err) {
129
      return {
130
        error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
131
      };
132
    }
133
  }
134
}
135

136
export class LocalAiCompletionProvider extends LocalAiGenericProvider {
137
  async callApi(prompt: string): Promise<ProviderResponse> {
138
    const body = {
139
      model: this.modelName,
140
      prompt,
141
      temperature: this.config.temperature || process.env.LOCALAI_TEMPERATURE || 0.7,
142
    };
143
    logger.debug(`Calling LocalAI API: ${JSON.stringify(body)}`);
144

145
    let data,
146
      cached = false;
147
    try {
148
      ({ data, cached } = (await fetchWithCache(
149
        `${this.apiBaseUrl}/completions`,
150
        {
151
          method: 'POST',
152
          headers: {
153
            'Content-Type': 'application/json',
154
          },
155
          body: JSON.stringify(body),
156
        },
157
        REQUEST_TIMEOUT_MS,
158
      )) as unknown as any);
159
    } catch (err) {
160
      return {
161
        error: `API call error: ${String(err)}`,
162
      };
163
    }
164
    logger.debug(`\tLocalAI completions API response: ${JSON.stringify(data)}`);
165
    try {
166
      return {
167
        output: data.choices[0].text,
168
      };
169
    } catch (err) {
170
      return {
171
        error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
172
      };
173
    }
174
  }
175
}
176

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.