promptfoo

Форк
0
/
vertex.ts 
378 строк · 10.8 Кб
1
import logger from '../logger';
2
import { fetchWithCache } from '../cache';
3
import { parseChatPrompt, REQUEST_TIMEOUT_MS } from './shared';
4
import { getCache, isCacheEnabled } from '../cache';
5

6
import {
7
  maybeCoerceToGeminiFormat,
8
  type GeminiApiResponse,
9
  type GeminiResponseData,
10
  GeminiErrorResponse,
11
  Palm2ApiResponse,
12
} from './vertexUtil';
13

14
import type { GoogleAuth } from 'google-auth-library';
15

16
import type { ApiProvider, EnvOverrides, ProviderResponse, TokenUsage } from '../types.js';
17

18
/** Lazily created `GoogleAuth` instance, reused by every call to `getGoogleClient`. */
let cachedAuth: GoogleAuth | undefined;

/**
 * Returns an authenticated Google API client together with the project ID
 * reported by the auth library.
 *
 * `google-auth-library` is imported dynamically on first use, so the package
 * only has to be installed when this function is actually called. The
 * `GoogleAuth` instance is created once and cached in `cachedAuth`; the
 * client and the project ID are requested from it again on every call.
 *
 * @returns An object holding the result of `GoogleAuth.getClient()` as
 *   `client` and the result of `GoogleAuth.getProjectId()` as `projectId`.
 * @throws Error when `google-auth-library` cannot be imported.
 */
async function getGoogleClient() {
  if (!cachedAuth) {
    let GoogleAuthCtor;
    try {
      ({ GoogleAuth: GoogleAuthCtor } = await import('google-auth-library'));
    } catch {
      // Any import failure is reported as a missing peer dependency.
      throw new Error(
        'The google-auth-library package is required as a peer dependency. Please install it in your project or globally.',
      );
    }
    cachedAuth = new GoogleAuthCtor({
      scopes: 'https://www.googleapis.com/auth/cloud-platform',
    });
  }
  const client = await cachedAuth.getClient();
  const projectId = await cachedAuth.getProjectId();
  return { client, projectId };
}
38

39
/**
 * Provider-level configuration for the Vertex providers in this file.
 *
 * The first group of fields controls how the endpoint is addressed; each of
 * them can alternatively be supplied through a `VERTEX_*` environment
 * variable (see the corresponding getter on the provider class). The second
 * group is copied unchanged into the body of the request.
 */
interface VertexCompletionOptions {
  /** API key / access token; falls back to `VERTEX_API_KEY`. */
  apiKey?: string;
  /** Host name of the API endpoint; falls back to `VERTEX_API_HOST`, then to a host derived from the region. */
  apiHost?: string;
  /** Google Cloud project ID; falls back to `VERTEX_PROJECT_ID`. */
  projectId?: string;
  /** Region used in the endpoint URL; falls back to `VERTEX_REGION`, then `us-central1`. */
  region?: string;
  /** Model publisher used in the endpoint URL; falls back to `VERTEX_PUBLISHER`, then `google`. */
  publisher?: string;

  /** Forwarded as `context` in the request body. */
  context?: string;
  /** Forwarded as `examples` in the request body. */
  examples?: { input: string; output: string }[];
  /** Forwarded as `safetySettings` in the request body. */
  safetySettings?: { category: string; probability: string }[];
  /** Forwarded as `stopSequence` in the request body. */
  stopSequence?: string[];
  /** Forwarded as `temperature` in the request body. */
  temperature?: number;
  /** Forwarded as `maxOutputTokens` in the request body. */
  maxOutputTokens?: number;
  /** Forwarded as `topP` in the request body. */
  topP?: number;
  /** Forwarded as `topK` in the request body. */
  topK?: number;
}
55

56
/**
 * Base class for Google Vertex providers.
 *
 * Holds the model name, provider config and optional environment overrides,
 * and resolves connection settings. Each setting is looked up in this order:
 * provider config, `env` overrides passed to the constructor, `process.env`,
 * and finally a built-in default where one exists. Lookups use `||`, so an
 * empty string at any level falls through to the next source.
 *
 * `callApi` is not implemented here; subclasses are expected to override it.
 */
class VertexGenericProvider implements ApiProvider {
  modelName: string;

  config: VertexCompletionOptions;
  env?: EnvOverrides;

  /**
   * @param modelName Name of the Vertex model, used in `id()` and `toString()`.
   * @param options Optional provider config, a custom ID that replaces the
   *   default `id()` result, and environment overrides.
   */
  constructor(
    modelName: string,
    options: { config?: VertexCompletionOptions; id?: string; env?: EnvOverrides } = {},
  ) {
    const { config, id, env } = options;
    this.env = env;
    this.modelName = modelName;
    this.config = config || {};
    // A caller-supplied ID shadows the default `id()` method on this instance.
    this.id = id ? () => id : this.id;
  }

  /** @returns The provider ID, `vertex:<modelName>` unless a custom ID was given. */
  id(): string {
    return `vertex:${this.modelName}`;
  }

  /** @returns A human-readable label for this provider. */
  toString(): string {
    return `[Google Vertex Provider ${this.modelName}]`;
  }

  /**
   * @returns The API host from config or `VERTEX_API_HOST`, otherwise
   *   `<region>-aiplatform.googleapis.com`.
   */
  getApiHost(): string | undefined {
    return (
      this.config.apiHost ||
      this.env?.VERTEX_API_HOST ||
      process.env.VERTEX_API_HOST ||
      `${this.getRegion()}-aiplatform.googleapis.com`
    );
  }

  /**
   * Resolves the Google Cloud project ID.
   *
   * The project ID reported by `getGoogleClient()` is preferred; config and
   * `VERTEX_PROJECT_ID` are only consulted when that value is empty. Because
   * `getGoogleClient()` is awaited first, any error it throws propagates to
   * the caller before the fallbacks are considered.
   *
   * @returns The resolved project ID, or `undefined` if none is available.
   */
  async getProjectId() {
    return (
      (await getGoogleClient()).projectId ||
      this.config.projectId ||
      this.env?.VERTEX_PROJECT_ID ||
      process.env.VERTEX_PROJECT_ID
    );
  }

  /** @returns The API key from config or `VERTEX_API_KEY`, or `undefined` if unset. */
  getApiKey(): string | undefined {
    return this.config.apiKey || this.env?.VERTEX_API_KEY || process.env.VERTEX_API_KEY;
  }

  /** @returns The region from config or `VERTEX_REGION`, defaulting to `us-central1`. */
  getRegion(): string {
    return (
      this.config.region || this.env?.VERTEX_REGION || process.env.VERTEX_REGION || 'us-central1'
    );
  }

  /** @returns The publisher from config or `VERTEX_PUBLISHER`, defaulting to `google`. */
  getPublisher(): string | undefined {
    return (
      this.config.publisher ||
      this.env?.VERTEX_PUBLISHER ||
      process.env.VERTEX_PUBLISHER ||
      'google'
    );
  }

  /**
   * Placeholder that always rejects; concrete providers override it.
   *
   * @throws Error with the message `Not implemented`.
   */
  // @ts-ignore: Prompt is not used in this implementation
  async callApi(prompt: string): Promise<ProviderResponse> {
    throw new Error('Not implemented');
  }
}
123

124
/**
 * Chat provider for Google Vertex models.
 *
 * Models whose name contains `gemini` are sent to the `streamGenerateContent`
 * endpoint; every other model is sent to the `predict` endpoint. Successful
 * responses are stored in the shared cache (when caching is enabled), keyed
 * by the serialized request body.
 */
export class VertexChatProvider extends VertexGenericProvider {
  // TODO(ian): Completion models
  // https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions
  static CHAT_MODELS = [
    'chat-bison',
    'chat-bison@001',
    'chat-bison@002',
    'chat-bison-32k',
    'chat-bison-32k@001',
    'chat-bison-32k@002',
    'codechat-bison',
    'codechat-bison@001',
    'codechat-bison@002',
    'codechat-bison-32k',
    'codechat-bison-32k@001',
    'codechat-bison-32k@002',
    'gemini-pro',
    'gemini-ultra',
    'gemini-1.0-pro-vision',
    'gemini-1.0-pro-vision-001',
    'gemini-1.0-pro',
    'gemini-1.0-pro-001',
    'gemini-pro-vision',
    'gemini-1.5-pro-latest',
    'aqa',
  ];

  /**
   * @param modelName Name of the chat model. Names missing from `CHAT_MODELS`
   *   are still accepted, but a warning is logged.
   * @param options Optional provider config, custom ID and environment overrides.
   */
  constructor(
    modelName: string,
    options: { config?: VertexCompletionOptions; id?: string; env?: EnvOverrides } = {},
  ) {
    if (!VertexChatProvider.CHAT_MODELS.includes(modelName)) {
      logger.warn(`Using unknown Google Vertex chat model: ${modelName}`);
    }
    super(modelName, options);
  }

  /**
   * Validates the configuration and dispatches the prompt to the Gemini or
   * PaLM 2 implementation depending on the model name.
   *
   * @param prompt Raw prompt text, or a JSON-encoded chat prompt.
   * @returns The provider response; API-level failures are reported through
   *   its `error` field.
   * @throws Error when no API key or no project ID can be resolved.
   */
  async callApi(prompt: string): Promise<ProviderResponse> {
    if (!this.getApiKey()) {
      throw new Error(
        'Google Vertex API key is not set. Set the VERTEX_API_KEY environment variable or add `apiKey` to the provider config. You can get an API token by running `gcloud auth print-access-token`',
      );
    }
    // getProjectId() is async: without `await` the check would test a Promise
    // (always truthy) and a rejection from it would go unhandled.
    if (!(await this.getProjectId())) {
      throw new Error(
        'Google Vertex project ID is not set. Set the VERTEX_PROJECT_ID environment variable or add `projectId` to the provider config.',
      );
    }

    if (this.modelName.includes('gemini')) {
      return this.callGeminiApi(prompt);
    }
    return this.callPalm2Api(prompt);
  }

  /**
   * Looks up a previously stored response for `cacheKey`.
   *
   * @param cache Cache instance obtained from `getCache()`.
   * @param cacheKey Key under which the serialized response was stored.
   * @param prompt Original prompt, used only for the debug log message.
   * @returns The stored response marked as `cached: true`, with
   *   `tokenUsage.cached` set to the stored total when token usage is
   *   present; `undefined` when caching is disabled or nothing is stored.
   */
  private async readCachedResponse(
    cache: Awaited<ReturnType<typeof getCache>>,
    cacheKey: string,
    prompt: string,
  ): Promise<ProviderResponse | undefined> {
    if (!isCacheEnabled()) {
      return undefined;
    }
    const cachedResponse = await cache.get(cacheKey);
    if (!cachedResponse) {
      return undefined;
    }
    logger.debug(`Returning cached response for prompt: ${prompt}`);
    const parsedCachedResponse = JSON.parse(cachedResponse as string);
    const tokenUsage = parsedCachedResponse.tokenUsage as TokenUsage;
    if (tokenUsage) {
      tokenUsage.cached = tokenUsage.total;
    }
    return { ...parsedCachedResponse, cached: true };
  }

  /**
   * Sends the prompt to the Gemini `streamGenerateContent` endpoint.
   *
   * The response is treated as an array of chunks: the first part of the
   * first candidate of every chunk is concatenated into the output, and
   * token usage is read from the last chunk.
   *
   * @param prompt Raw prompt text, or a JSON-encoded chat prompt.
   * @returns The concatenated output with token usage, or an object with an
   *   `error` message when the request or response handling fails.
   */
  async callGeminiApi(prompt: string): Promise<ProviderResponse> {
    // https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#gemini-pro
    let contents = parseChatPrompt(prompt, {
      role: 'user',
      parts: {
        text: prompt,
      },
    });
    const { contents: updatedContents, coerced } = maybeCoerceToGeminiFormat(contents);
    if (coerced) {
      logger.debug(`Coerced JSON prompt to Gemini format: ${JSON.stringify(contents)}`);
      contents = updatedContents;
    }

    // https://ai.google.dev/api/rest/v1/models/streamGenerateContent
    const body = {
      contents,
      generationConfig: {
        context: this.config.context,
        examples: this.config.examples,
        stopSequence: this.config.stopSequence,
        temperature: this.config.temperature,
        maxOutputTokens: this.config.maxOutputTokens,
        topP: this.config.topP,
        topK: this.config.topK,
      },
      safetySettings: this.config.safetySettings,
    };
    logger.debug(`Preparing to call Google Vertex API (Gemini) with body: ${JSON.stringify(body)}`);

    const cache = await getCache();
    const cacheKey = `vertex:gemini:${JSON.stringify(body)}`;

    const cachedResult = await this.readCachedResponse(cache, cacheKey, prompt);
    if (cachedResult) {
      return cachedResult;
    }

    let data: GeminiApiResponse;
    try {
      const { client, projectId } = await getGoogleClient();
      const url = `https://${this.getApiHost()}/v1/projects/${projectId}/locations/${this.getRegion()}/publishers/${this.getPublisher()}/models/${this.modelName}:streamGenerateContent`;
      const res = await client.request({
        url,
        method: 'POST',
        data: body,
      });
      data = res.data as GeminiApiResponse;
    } catch (err) {
      return {
        error: `API call error: ${JSON.stringify(err)}`,
      };
    }

    logger.debug(`Gemini API response: ${JSON.stringify(data)}`);
    try {
      // An error payload is detected on the first element of the response.
      const dataWithError = data as GeminiErrorResponse[];
      const error = dataWithError[0].error;
      if (error) {
        return {
          error: `Error ${error.code}: ${error.message}`,
        };
      }
      const dataWithResponse = data as GeminiResponseData[];
      const output = dataWithResponse
        .map((datum: GeminiResponseData) => {
          const part = datum.candidates[0].content.parts[0];
          if ('text' in part) {
            return part.text;
          }
          // Parts without text are included as serialized JSON.
          return JSON.stringify(part);
        })
        .join('');
      const lastData = dataWithResponse[dataWithResponse.length - 1];
      const tokenUsage = {
        total: lastData.usageMetadata?.totalTokenCount || 0,
        prompt: lastData.usageMetadata?.promptTokenCount || 0,
        completion: lastData.usageMetadata?.candidatesTokenCount || 0,
      };
      const response = {
        cached: false,
        output,
        tokenUsage,
      };

      if (isCacheEnabled()) {
        await cache.set(cacheKey, JSON.stringify(response));
      }

      return response;
    } catch (err) {
      return {
        error: `Gemini API response error: ${String(err)}: ${JSON.stringify(data)}`,
      };
    }
  }

  /**
   * Sends the prompt to the `predict` endpoint used for non-Gemini models.
   *
   * @param prompt Raw prompt text, or a JSON-encoded chat prompt.
   * @returns The content of the first candidate of the first prediction, or
   *   an object with an `error` message when the request or response
   *   handling fails.
   */
  async callPalm2Api(prompt: string): Promise<ProviderResponse> {
    const instances = parseChatPrompt(prompt, [
      {
        messages: [
          {
            author: 'user',
            content: prompt,
          },
        ],
      },
    ]);

    const body = {
      instances,
      parameters: {
        context: this.config.context,
        examples: this.config.examples,
        safetySettings: this.config.safetySettings,
        stopSequence: this.config.stopSequence,
        temperature: this.config.temperature,
        maxOutputTokens: this.config.maxOutputTokens,
        topP: this.config.topP,
        topK: this.config.topK,
      },
    };
    logger.debug(`Calling Vertex Palm2 API: ${JSON.stringify(body)}`);

    const cache = await getCache();
    const cacheKey = `vertex:palm2:${JSON.stringify(body)}`;

    const cachedResult = await this.readCachedResponse(cache, cacheKey, prompt);
    if (cachedResult) {
      return cachedResult;
    }

    let data: Palm2ApiResponse;
    try {
      const { client, projectId } = await getGoogleClient();
      const url = `https://${this.getApiHost()}/v1/projects/${projectId}/locations/${this.getRegion()}/publishers/${this.getPublisher()}/models/${this.modelName}:predict`;
      const res = await client.request<Palm2ApiResponse>({
        url,
        method: 'POST',
        headers: {
          'Content-Type': 'application/json',
          Authorization: `Bearer ${this.getApiKey()}`,
        },
        data: body,
      });
      data = res.data;
    } catch (err) {
      return {
        error: `API call error: ${JSON.stringify(err)}`,
      };
    }

    logger.debug(`Vertex Palm2 API response: ${JSON.stringify(data)}`);
    try {
      if (data.error) {
        return {
          error: `Error ${data.error.code}: ${data.error.message}`,
        };
      }
      const output = data.predictions?.[0].candidates[0].content;

      const response = {
        output,
        cached: false,
      };

      if (isCacheEnabled()) {
        await cache.set(cacheKey, JSON.stringify(response));
      }

      return response;
    } catch (err) {
      return {
        error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,
      };
    }
  }
}
379

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.