1
import logger from '../logger';
2
import { fetchWithCache } from '../cache';
3
import { parseChatPrompt, REQUEST_TIMEOUT_MS } from './shared';
4
import { getCache, isCacheEnabled } from '../cache';
7
maybeCoerceToGeminiFormat,
8
type GeminiApiResponse,
9
type GeminiResponseData,
14
import type { GoogleAuth } from 'google-auth-library';
16
import type { ApiProvider, EnvOverrides, ProviderResponse, TokenUsage } from '../types.js';
18
let cachedAuth: GoogleAuth | undefined;
19
async function getGoogleClient() {
23
const importedModule = await import('google-auth-library');
24
GoogleAuth = importedModule.GoogleAuth;
27
'The google-auth-library package is required as a peer dependency. Please install it in your project or globally.',
30
cachedAuth = new GoogleAuth({
31
scopes: 'https://www.googleapis.com/auth/cloud-platform',
34
const client = await cachedAuth.getClient();
35
const projectId = await cachedAuth.getProjectId();
36
return { client, projectId };
39
interface VertexCompletionOptions {
47
examples?: { input: string; output: string }[];
48
safetySettings?: { category: string; probability: string }[];
49
stopSequence?: string[];
51
maxOutputTokens?: number;
56
class VertexGenericProvider implements ApiProvider {
59
config: VertexCompletionOptions;
64
options: { config?: VertexCompletionOptions; id?: string; env?: EnvOverrides } = {},
66
const { config, id, env } = options;
68
this.modelName = modelName;
69
this.config = config || {};
70
this.id = id ? () => id : this.id;
74
return `vertex:${this.modelName}`;
78
return `[Google Vertex Provider ${this.modelName}]`;
81
getApiHost(): string | undefined {
83
this.config.apiHost ||
84
this.env?.VERTEX_API_HOST ||
85
process.env.VERTEX_API_HOST ||
86
`${this.getRegion()}-aiplatform.googleapis.com`
90
async getProjectId() {
92
(await getGoogleClient()).projectId ||
93
this.config.projectId ||
94
this.env?.VERTEX_PROJECT_ID ||
95
process.env.VERTEX_PROJECT_ID
99
getApiKey(): string | undefined {
100
return this.config.apiKey || this.env?.VERTEX_API_KEY || process.env.VERTEX_API_KEY;
103
getRegion(): string {
105
this.config.region || this.env?.VERTEX_REGION || process.env.VERTEX_REGION || 'us-central1'
109
getPublisher(): string | undefined {
111
this.config.publisher ||
112
this.env?.VERTEX_PUBLISHER ||
113
process.env.VERTEX_PUBLISHER ||
118
// @ts-ignore: Prompt is not used in this implementation
119
async callApi(prompt: string): Promise<ProviderResponse> {
120
throw new Error('Not implemented');
124
export class VertexChatProvider extends VertexGenericProvider {
125
// TODO(ian): Completion models
126
// https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#gemini-model-versions
127
static CHAT_MODELS = [
132
'chat-bison-32k@001',
133
'chat-bison-32k@002',
135
'codechat-bison@001',
136
'codechat-bison@002',
137
'codechat-bison-32k',
138
'codechat-bison-32k@001',
139
'codechat-bison-32k@002',
142
'gemini-1.0-pro-vision',
143
'gemini-1.0-pro-vision-001',
145
'gemini-1.0-pro-001',
147
'gemini-1.5-pro-latest',
153
options: { config?: VertexCompletionOptions; id?: string; env?: EnvOverrides } = {},
155
if (!VertexChatProvider.CHAT_MODELS.includes(modelName)) {
156
logger.warn(`Using unknown Google Vertex chat model: ${modelName}`);
158
super(modelName, options);
161
async callApi(prompt: string): Promise<ProviderResponse> {
162
if (!this.getApiKey()) {
164
'Google Vertex API key is not set. Set the VERTEX_API_KEY environment variable or add `apiKey` to the provider config. You can get an API token by running `gcloud auth print-access-token`',
167
if (!this.getProjectId()) {
169
'Google Vertex project ID is not set. Set the VERTEX_PROJECT_ID environment variable or add `projectId` to the provider config.',
173
if (this.modelName.includes('gemini')) {
174
return this.callGeminiApi(prompt);
176
return this.callPalm2Api(prompt);
179
async callGeminiApi(prompt: string): Promise<ProviderResponse> {
180
// https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#gemini-pro
181
let contents = parseChatPrompt(prompt, {
187
const { contents: updatedContents, coerced } = maybeCoerceToGeminiFormat(contents);
189
logger.debug(`Coerced JSON prompt to Gemini format: ${JSON.stringify(contents)}`);
190
contents = updatedContents;
193
// https://ai.google.dev/api/rest/v1/models/streamGenerateContent
197
context: this.config.context,
198
examples: this.config.examples,
199
stopSequence: this.config.stopSequence,
200
temperature: this.config.temperature,
201
maxOutputTokens: this.config.maxOutputTokens,
202
topP: this.config.topP,
203
topK: this.config.topK,
205
safetySettings: this.config.safetySettings,
207
logger.debug(`Preparing to call Google Vertex API (Gemini) with body: ${JSON.stringify(body)}`);
209
const cache = await getCache();
210
const cacheKey = `vertex:gemini:${JSON.stringify(body)}`;
213
if (isCacheEnabled()) {
214
cachedResponse = await cache.get(cacheKey);
215
if (cachedResponse) {
216
logger.debug(`Returning cached response for prompt: ${prompt}`);
217
const parsedCachedResponse = JSON.parse(cachedResponse as string);
218
const tokenUsage = parsedCachedResponse.tokenUsage as TokenUsage;
220
tokenUsage.cached = tokenUsage.total;
222
return { ...parsedCachedResponse, cached: true };
228
const { client, projectId } = await getGoogleClient();
229
const url = `https://${this.getApiHost()}/v1/projects/${projectId}/locations/${this.getRegion()}/publishers/${this.getPublisher()}/models/${
231
}:streamGenerateContent`;
232
const res = await client.request({
237
data = res.data as GeminiApiResponse;
240
error: `API call error: ${JSON.stringify(err)}`,
244
logger.debug(`Gemini API response: ${JSON.stringify(data)}`);
246
const dataWithError = data as GeminiErrorResponse[];
247
const error = dataWithError[0].error;
250
error: `Error ${error.code}: ${error.message}`,
253
const dataWithResponse = data as GeminiResponseData[];
254
const output = dataWithResponse
255
.map((datum: GeminiResponseData) => {
256
const part = datum.candidates[0].content.parts[0];
257
if ('text' in part) {
260
return JSON.stringify(part);
263
const lastData = dataWithResponse[dataWithResponse.length - 1];
265
total: lastData.usageMetadata?.totalTokenCount || 0,
266
prompt: lastData.usageMetadata?.promptTokenCount || 0,
267
completion: lastData.usageMetadata?.candidatesTokenCount || 0,
275
if (isCacheEnabled()) {
276
await cache.set(cacheKey, JSON.stringify(response));
282
error: `Gemini API response error: ${String(err)}: ${JSON.stringify(data)}`,
287
async callPalm2Api(prompt: string): Promise<ProviderResponse> {
288
const instances = parseChatPrompt(prompt, [
302
context: this.config.context,
303
examples: this.config.examples,
304
safetySettings: this.config.safetySettings,
305
stopSequence: this.config.stopSequence,
306
temperature: this.config.temperature,
307
maxOutputTokens: this.config.maxOutputTokens,
308
topP: this.config.topP,
309
topK: this.config.topK,
312
logger.debug(`Calling Vertex Palm2 API: ${JSON.stringify(body)}`);
314
const cache = await getCache();
315
const cacheKey = `vertex:palm2:${JSON.stringify(body)}`;
318
if (isCacheEnabled()) {
319
cachedResponse = await cache.get(cacheKey);
320
if (cachedResponse) {
321
logger.debug(`Returning cached response for prompt: ${prompt}`);
322
const parsedCachedResponse = JSON.parse(cachedResponse as string);
323
const tokenUsage = parsedCachedResponse.tokenUsage as TokenUsage;
325
tokenUsage.cached = tokenUsage.total;
327
return { ...parsedCachedResponse, cached: true };
331
let data: Palm2ApiResponse;
333
const { client, projectId } = await getGoogleClient();
334
const url = `https://${this.getApiHost()}/v1/projects/${projectId}/locations/${this.getRegion()}/publishers/${this.getPublisher()}/models/${
337
const res = await client.request<Palm2ApiResponse>({
341
'Content-Type': 'application/json',
342
Authorization: `Bearer ${this.getApiKey()}`,
349
error: `API call error: ${JSON.stringify(err)}`,
353
logger.debug(`Vertex Palm2 API response: ${JSON.stringify(data)}`);
357
error: `Error ${data.error.code}: ${data.error.message}`,
360
const output = data.predictions?.[0].candidates[0].content;
367
if (isCacheEnabled()) {
368
await cache.set(cacheKey, JSON.stringify(response));
374
error: `API response error: ${String(err)}: ${JSON.stringify(data)}`,