promptfoo — openai.ts (848 lines · 25.2 KB)
import OpenAI from 'openai';

import logger from '../logger';
import { fetchWithCache, getCache, isCacheEnabled } from '../cache';
import { REQUEST_TIMEOUT_MS, parseChatPrompt, toTitleCase } from './shared';
import { OpenAiFunction, OpenAiTool } from './openaiUtil';

import type {
  ApiProvider,
  CallApiContextParams,
  CallApiOptionsParams,
  EnvOverrides,
  ProviderEmbeddingResponse,
  ProviderResponse,
  TokenUsage,
} from '../types.js';

interface OpenAiSharedOptions {
  apiKey?: string;
  apiKeyEnvar?: string;
  apiHost?: string;
  apiBaseUrl?: string;
  organization?: string;
  cost?: number;
}

type OpenAiCompletionOptions = OpenAiSharedOptions & {
  temperature?: number;
  max_tokens?: number;
  top_p?: number;
  frequency_penalty?: number;
  presence_penalty?: number;
  best_of?: number;
  functions?: OpenAiFunction[];
  function_call?: 'none' | 'auto' | { name: string };
  tools?: OpenAiTool[];
  tool_choice?: 'none' | 'auto' | { type: 'function'; function?: { name: string } };
  response_format?: { type: 'json_object' };
  stop?: string[];
  seed?: number;
  passthrough?: object;
};

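// Normalizes a thrown error into the `{ error: string }` shape used by
// ProviderResponse, preserving OpenAI.APIError type information when available.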
function failApiCall(err: any) {
  if (err instanceof OpenAI.APIError) {
    return {
      error: `API error: ${err.type} ${err.message}`,
    };
  }
  return {
    error: `API error: ${String(err)}`,
  };
}

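// Extracts token counts from an API response body. Cached responses report
// their totals under `cached` so they are not counted as fresh usage.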
function getTokenUsage(data: any, cached: boolean): Partial<TokenUsage> {
  if (data.usage) {
    if (cached) {
      return { cached: data.usage.total_tokens, total: data.usage.total_tokens };
    } else {
      return {
        total: data.usage.total_tokens,
        prompt: data.usage.prompt_tokens || 0,
        completion: data.usage.completion_tokens || 0,
      };
    }
  }
  return {};
}

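// Base class for all OpenAI providers. Resolves the API key, organization, and
// base URL from the provider config, `env` overrides, or the process
// environment (in that order). Subclasses implement `callApi`.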
export class OpenAiGenericProvider implements ApiProvider {
  modelName: string;

  config: OpenAiSharedOptions;
  env?: EnvOverrides;

  constructor(
    modelName: string,
    options: { config?: OpenAiSharedOptions; id?: string; env?: EnvOverrides } = {},
  ) {
    const { config, id, env } = options;
    this.env = env;
    this.modelName = modelName;
    this.config = config || {};
    this.id = id ? () => id : this.id;
  }

  id(): string {
    return `openai:${this.modelName}`;
  }

  toString(): string {
    return `[OpenAI Provider ${this.modelName}]`;
  }

  getOrganization(): string | undefined {
    return (
      this.config.organization || this.env?.OPENAI_ORGANIZATION || process.env.OPENAI_ORGANIZATION
    );
  }

  getApiUrlDefault(): string {
    return 'https://api.openai.com/v1';
  }

  getApiUrl(): string {
    const apiHost = this.config.apiHost || this.env?.OPENAI_API_HOST || process.env.OPENAI_API_HOST;
    if (apiHost) {
      return `https://${apiHost}/v1`;
    }
    return (
      this.config.apiBaseUrl ||
      this.env?.OPENAI_API_BASE_URL ||
      process.env.OPENAI_API_BASE_URL ||
      this.getApiUrlDefault()
    );
  }

  getApiKey(): string | undefined {
    return (
      this.config.apiKey ||
      (this.config?.apiKeyEnvar
        ? process.env[this.config.apiKeyEnvar] ||
          this.env?.[this.config.apiKeyEnvar as keyof EnvOverrides]
        : undefined) ||
      this.env?.OPENAI_API_KEY ||
      process.env.OPENAI_API_KEY
    );
  }

  // @ts-ignore: Params are not used in this implementation
  async callApi(
    prompt: string,
    context?: CallApiContextParams,
    callApiOptions?: CallApiOptionsParams,
  ): Promise<ProviderResponse> {
    throw new Error('Not implemented');
  }
}

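// Provider for the /embeddings endpoint, used for similarity comparisons.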
export class OpenAiEmbeddingProvider extends OpenAiGenericProvider {
  async callEmbeddingApi(text: string): Promise<ProviderEmbeddingResponse> {
    if (!this.getApiKey()) {
      throw new Error('OpenAI API key must be set for similarity comparison');
    }

    const body = {
      input: text,
      model: this.modelName,
    };
    let data,
      cached = false;
    try {
      ({ data, cached } = (await fetchWithCache(
        `${this.getApiUrl()}/embeddings`,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            Authorization: `Bearer ${this.getApiKey()}`,
            ...(this.getOrganization() ? { 'OpenAI-Organization': this.getOrganization() } : {}),
          },
          body: JSON.stringify(body),
        },
        REQUEST_TIMEOUT_MS,
      )) as unknown as any);
    } catch (err) {
      logger.error(`API call error: ${err}`);
      throw err;
    }
    logger.debug(`\tOpenAI embeddings API response: ${JSON.stringify(data)}`);

    try {
      const embedding = data?.data?.[0]?.embedding;
      if (!embedding) {
        throw new Error('No embedding found in OpenAI embeddings API response');
      }
      return {
        embedding,
        tokenUsage: getTokenUsage(data, cached),
      };
    } catch (err) {
      logger.error(data?.error?.message || String(err));
      throw err;
    }
  }
}

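// Provider for the legacy /completions endpoint (non-chat models such as
// gpt-3.5-turbo-instruct).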
export class OpenAiCompletionProvider extends OpenAiGenericProvider {
  static OPENAI_COMPLETION_MODELS = [
    {
      id: 'gpt-3.5-turbo-instruct',
      cost: {
        input: 0.0015 / 1000,
        output: 0.002 / 1000,
      },
    },
    {
      id: 'gpt-3.5-turbo-instruct-0914',
      cost: {
        input: 0.0015 / 1000,
        output: 0.002 / 1000,
      },
    },
    {
      id: 'text-davinci-003',
    },
    {
      id: 'text-davinci-002',
    },
    {
      id: 'text-curie-001',
    },
    {
      id: 'text-babbage-001',
    },
    {
      id: 'text-ada-001',
    },
  ];

  static OPENAI_COMPLETION_MODEL_NAMES = OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.map(
    (model) => model.id,
  );

  config: OpenAiCompletionOptions;

  constructor(
    modelName: string,
    options: { config?: OpenAiCompletionOptions; id?: string; env?: EnvOverrides } = {},
  ) {
    super(modelName, options);
    this.config = options.config || {};
    if (
      !OpenAiCompletionProvider.OPENAI_COMPLETION_MODEL_NAMES.includes(modelName) &&
      this.getApiUrl() === this.getApiUrlDefault()
    ) {
      logger.warn(`FYI: Using unknown OpenAI completion model: ${modelName}`);
    }
  }

  async callApi(
    prompt: string,
    context?: CallApiContextParams,
    callApiOptions?: CallApiOptionsParams,
  ): Promise<ProviderResponse> {
    if (!this.getApiKey()) {
      throw new Error(
        'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
      );
    }

    let stop: string[];
    try {
      stop = process.env.OPENAI_STOP
        ? JSON.parse(process.env.OPENAI_STOP)
        : this.config?.stop || ['<|im_end|>', '<|endoftext|>'];
    } catch (err) {
      throw new Error(`OPENAI_STOP is not a valid JSON string: ${err}`);
    }
    const body = {
      model: this.modelName,
      prompt,
      seed: this.config.seed || 0,
      max_tokens: this.config.max_tokens ?? parseInt(process.env.OPENAI_MAX_TOKENS || '1024', 10),
      temperature: this.config.temperature ?? parseFloat(process.env.OPENAI_TEMPERATURE || '0'),
      top_p: this.config.top_p ?? parseFloat(process.env.OPENAI_TOP_P || '1'),
      presence_penalty:
        this.config.presence_penalty ?? parseFloat(process.env.OPENAI_PRESENCE_PENALTY || '0'),
      frequency_penalty:
        this.config.frequency_penalty ?? parseFloat(process.env.OPENAI_FREQUENCY_PENALTY || '0'),
      best_of: this.config.best_of ?? parseInt(process.env.OPENAI_BEST_OF || '1', 10),
      ...(callApiOptions?.includeLogProbs ? { logprobs: callApiOptions.includeLogProbs } : {}),
      ...(stop ? { stop } : {}),
      ...(this.config.passthrough || {}),
    };
    logger.debug(`Calling OpenAI API: ${JSON.stringify(body)}`);
    let data,
      cached = false;
    try {
      ({ data, cached } = (await fetchWithCache(
        `${this.getApiUrl()}/completions`,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            Authorization: `Bearer ${this.getApiKey()}`,
            ...(this.getOrganization() ? { 'OpenAI-Organization': this.getOrganization() } : {}),
          },
          body: JSON.stringify(body),
        },
        REQUEST_TIMEOUT_MS,
      )) as unknown as any);
    } catch (err) {
      return {
        error: `API call error: ${String(err)}`,
      };
    }
    logger.debug(`\tOpenAI completions API response: ${JSON.stringify(data)}`);
    if (data.error) {
      return {
        error: formatOpenAiError(data),
      };
    }
    try {
      return {
        output: data.choices[0].text,
        tokenUsage: getTokenUsage(data, cached),
        cached,
        cost: calculateCost(
          this.modelName,
          this.config,
          data.usage?.prompt_tokens,
          data.usage?.completion_tokens,
        ),
      };
    } catch (err) {
      return {
        error: `API error: ${String(err)}: ${JSON.stringify(data)}`,
      };
    }
  }
}

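// Provider for the /chat/completions endpoint. Supports functions, tools,
// response_format, and logprobs in addition to the standard sampling options.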
export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
  static OPENAI_CHAT_MODELS = [
    ...['gpt-4', 'gpt-4-0314', 'gpt-4-0613'].map((model) => ({
      id: model,
      cost: {
        input: 0.03 / 1000,
        output: 0.06 / 1000,
      },
    })),
    ...[
      'gpt-4-1106-preview',
      'gpt-4-1106-vision-preview',
      'gpt-4-0125-preview',
      'gpt-4-turbo-preview',
    ].map((model) => ({
      id: model,
      cost: {
        input: 0.01 / 1000,
        output: 0.03 / 1000,
      },
    })),
    ...['gpt-4-32k', 'gpt-4-32k-0314'].map((model) => ({
      id: model,
      cost: {
        input: 0.06 / 1000,
        output: 0.12 / 1000,
      },
    })),
    ...[
      'gpt-3.5-turbo',
      'gpt-3.5-turbo-0301',
      'gpt-3.5-turbo-0613',
      'gpt-3.5-turbo-1106',
      'gpt-3.5-turbo-0125',
      'gpt-3.5-turbo-16k',
      'gpt-3.5-turbo-16k-0613',
    ].map((model) => ({
      id: model,
      cost: {
        input: 0.0005 / 1000,
        output: 0.0015 / 1000,
      },
    })),
  ];

  static OPENAI_CHAT_MODEL_NAMES = OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.map(
    (model) => model.id,
  );

  config: OpenAiCompletionOptions;

  constructor(
    modelName: string,
    options: { config?: OpenAiCompletionOptions; id?: string; env?: EnvOverrides } = {},
  ) {
    if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODEL_NAMES.includes(modelName)) {
      logger.warn(`Using unknown OpenAI chat model: ${modelName}`);
    }
    super(modelName, options);
    this.config = options.config || {};
  }

  async callApi(
    prompt: string,
    context?: CallApiContextParams,
    callApiOptions?: CallApiOptionsParams,
  ): Promise<ProviderResponse> {
    if (!this.getApiKey()) {
      throw new Error(
        'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
      );
    }

    const messages = parseChatPrompt(prompt, [{ role: 'user', content: prompt }]);

    let stop: string[];
    try {
      stop = process.env.OPENAI_STOP
        ? JSON.parse(process.env.OPENAI_STOP)
        : this.config?.stop || [];
    } catch (err) {
      throw new Error(`OPENAI_STOP is not a valid JSON string: ${err}`);
    }
    const body = {
      model: this.modelName,
      messages: messages,
      seed: this.config.seed || 0,
      max_tokens: this.config.max_tokens ?? parseInt(process.env.OPENAI_MAX_TOKENS || '1024', 10),
      temperature: this.config.temperature ?? parseFloat(process.env.OPENAI_TEMPERATURE || '0'),
      top_p: this.config.top_p ?? parseFloat(process.env.OPENAI_TOP_P || '1'),
      presence_penalty:
        this.config.presence_penalty ?? parseFloat(process.env.OPENAI_PRESENCE_PENALTY || '0'),
      frequency_penalty:
        this.config.frequency_penalty ?? parseFloat(process.env.OPENAI_FREQUENCY_PENALTY || '0'),
      ...(this.config.functions ? { functions: this.config.functions } : {}),
      ...(this.config.function_call ? { function_call: this.config.function_call } : {}),
      ...(this.config.tools ? { tools: this.config.tools } : {}),
      ...(this.config.tool_choice ? { tool_choice: this.config.tool_choice } : {}),
      ...(this.config.response_format ? { response_format: this.config.response_format } : {}),
      ...(callApiOptions?.includeLogProbs ? { logprobs: callApiOptions.includeLogProbs } : {}),
      ...(stop.length > 0 ? { stop } : {}),
      ...(this.config.passthrough || {}),
    };
    logger.debug(`Calling OpenAI API: ${JSON.stringify(body)}`);

    let data,
      cached = false;
    try {
      ({ data, cached } = (await fetchWithCache(
        `${this.getApiUrl()}/chat/completions`,
        {
          method: 'POST',
          headers: {
            'Content-Type': 'application/json',
            Authorization: `Bearer ${this.getApiKey()}`,
            ...(this.getOrganization() ? { 'OpenAI-Organization': this.getOrganization() } : {}),
          },
          body: JSON.stringify(body),
        },
        REQUEST_TIMEOUT_MS,
      )) as unknown as { data: any; cached: boolean });
    } catch (err) {
      return {
        error: `API call error: ${String(err)}`,
      };
    }

    logger.debug(`\tOpenAI chat completions API response: ${JSON.stringify(data)}`);
    if (data.error) {
      return {
        error: formatOpenAiError(data),
      };
    }
    try {
      const message = data.choices[0].message;
      const output =
        message.content === null ? message.function_call || message.tool_calls : message.content;
      const logProbs = data.choices[0].logprobs?.content?.map(
        (logProbObj: { token: string; logprob: number }) => logProbObj.logprob,
      );

      return {
        output,
        tokenUsage: getTokenUsage(data, cached),
        cached,
        logProbs,
        cost: calculateCost(
          this.modelName,
          this.config,
          data.usage?.prompt_tokens,
          data.usage?.completion_tokens,
        ),
      };
    } catch (err) {
      return {
        error: `API error: ${String(err)}: ${JSON.stringify(data)}`,
      };
    }
  }
}

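// Formats an error payload returned inside an API response body (as opposed
// to a thrown exception, which is handled by `failApiCall`).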
function formatOpenAiError(data: { error: { message: string; type?: string; code?: string } }) {
  return (
    `API error: ${data.error.message}` +
    (data.error.type ? `, Type: ${data.error.type}` : '') +
    (data.error.code ? `, Code: ${data.error.code}` : '')
  );
}

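// Computes the dollar cost of a request from the per-token prices in the model
// tables above. A flat `cost` in the provider config overrides both the input
// and output prices. Returns undefined for unknown models or missing usage.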
function calculateCost(
  modelName: string,
  config: OpenAiSharedOptions,
  promptTokens?: number,
  completionTokens?: number,
): number | undefined {
  if (!promptTokens || !completionTokens) {
    return undefined;
  }

  const model = [
    ...OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS,
    ...OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS,
  ].find((m) => m.id === modelName);
  if (!model || !model.cost) {
    return undefined;
  }

  const inputCost = config.cost ?? model.cost.input;
  const outputCost = config.cost ?? model.cost.output;
  return inputCost * promptTokens + outputCost * completionTokens || undefined;
}

interface AssistantMessagesResponseDataContent {
  type: string;
  text?: {
    value: string;
  };
}

interface AssistantMessagesResponseData {
  data: {
    role: string;
    content?: AssistantMessagesResponseDataContent[];
  }[];
}

type OpenAiAssistantOptions = OpenAiSharedOptions & {
  modelName?: string;
  instructions?: string;
  tools?: OpenAI.Beta.Threads.ThreadCreateAndRunParams['tools'];
  /**
   * If set, automatically call these functions when the assistant activates
   * these function tools.
   */
  functionToolCallbacks?: Record<
    OpenAI.FunctionDefinition['name'],
    (arg: string) => Promise<string>
  >;
  metadata?: object[];
};

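// Provider for the beta Assistants API. Creates a thread run, polls it until
// it settles, invokes any configured functionToolCallbacks when the run
// requires action, and flattens the run steps into a single text output.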
export class OpenAiAssistantProvider extends OpenAiGenericProvider {
  assistantId: string;
  assistantConfig: OpenAiAssistantOptions;

  constructor(
    assistantId: string,
    options: { config?: OpenAiAssistantOptions; id?: string; env?: EnvOverrides } = {},
  ) {
    super(assistantId, options);
    this.assistantConfig = options.config || {};
    this.assistantId = assistantId;
  }

  async callApi(prompt: string): Promise<ProviderResponse> {
    if (!this.getApiKey()) {
      throw new Error(
        'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
      );
    }

    const openai = new OpenAI({
      apiKey: this.getApiKey(),
      organization: this.getOrganization(),
      // Unfortunate, but the OpenAI SDK's implementation of base URL is different from how we treat base URL elsewhere.
      baseURL: this.getApiUrl(),
      maxRetries: 3,
      timeout: REQUEST_TIMEOUT_MS,
    });

    const messages = parseChatPrompt(prompt, [
      { role: 'user', content: prompt },
    ]) as OpenAI.Beta.Threads.ThreadCreateParams.Message[];
    const body: OpenAI.Beta.Threads.ThreadCreateAndRunParams = {
      assistant_id: this.assistantId,
      model: this.assistantConfig.modelName || undefined,
      instructions: this.assistantConfig.instructions || undefined,
      tools: this.assistantConfig.tools || undefined,
      metadata: this.assistantConfig.metadata || undefined,
      thread: {
        messages,
      },
    };

    logger.debug(`Calling OpenAI API, creating thread run: ${JSON.stringify(body)}`);
    let run;
    try {
      run = await openai.beta.threads.createAndRun(body);
    } catch (err) {
      return failApiCall(err);
    }

    logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`);

    while (
      run.status === 'in_progress' ||
      run.status === 'queued' ||
      run.status === 'requires_action'
    ) {
      if (run.status === 'requires_action') {
        const requiredAction = run.required_action;
        if (requiredAction === null || requiredAction.type !== 'submit_tool_outputs') {
          break;
        }
        const functionCallsWithCallbacks = requiredAction.submit_tool_outputs.tool_calls.filter(
          (toolCall) => {
            return (
              toolCall.type === 'function' &&
              toolCall.function.name in (this.assistantConfig.functionToolCallbacks ?? {})
            );
          },
        );
        if (functionCallsWithCallbacks.length === 0) {
          break;
        }
        logger.debug(
          `Calling functionToolCallbacks for functions: ${functionCallsWithCallbacks.map(
            ({ function: { name } }) => name,
          )}`,
        );
        const toolOutputs = await Promise.all(
          functionCallsWithCallbacks.map(async (toolCall) => {
            logger.debug(
              `Calling functionToolCallbacks[${toolCall.function.name}]('${toolCall.function.arguments}')`,
            );
            const result = await this.assistantConfig.functionToolCallbacks![
              toolCall.function.name
            ](toolCall.function.arguments);
            return {
              tool_call_id: toolCall.id,
              output: result,
            };
          }),
        );
        logger.debug(
          `Calling OpenAI API, submitting tool outputs for ${run.thread_id}: ${JSON.stringify(
            toolOutputs,
          )}`,
        );
        try {
          run = await openai.beta.threads.runs.submitToolOutputs(run.thread_id, run.id, {
            tool_outputs: toolOutputs,
          });
        } catch (err) {
          return failApiCall(err);
        }
        continue;
      }

      await new Promise((resolve) => setTimeout(resolve, 1000));

      logger.debug(`Calling OpenAI API, getting thread run ${run.id} status`);
      try {
        run = await openai.beta.threads.runs.retrieve(run.thread_id, run.id);
      } catch (err) {
        return failApiCall(err);
      }
      logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`);
    }

    if (run.status !== 'completed' && run.status !== 'requires_action') {
      if (run.last_error) {
        return {
          error: `Thread run failed: ${run.last_error.message}`,
        };
      }
      return {
        error: `Thread run failed: ${run.status}`,
      };
    }

    // Get run steps
    logger.debug(`Calling OpenAI API, getting thread run steps for ${run.thread_id}`);
    let steps;
    try {
      steps = await openai.beta.threads.runs.steps.list(run.thread_id, run.id, {
        order: 'asc',
      });
    } catch (err) {
      return failApiCall(err);
    }
    logger.debug(`\tOpenAI thread run steps API response: ${JSON.stringify(steps)}`);

    const outputBlocks = [];
    for (const step of steps.data) {
      if (step.step_details.type === 'message_creation') {
        logger.debug(`Calling OpenAI API, getting message ${step.id}`);
        let message;
        try {
          message = await openai.beta.threads.messages.retrieve(
            run.thread_id,
            step.step_details.message_creation.message_id,
          );
        } catch (err) {
          return failApiCall(err);
        }
        logger.debug(`\tOpenAI thread run step message API response: ${JSON.stringify(message)}`);

        const content = message.content
          .map((content) =>
            content.type === 'text' ? content.text.value : `<${content.type} output>`,
          )
          .join('\n');
        outputBlocks.push(`[${toTitleCase(message.role)}] ${content}`);
      } else if (step.step_details.type === 'tool_calls') {
        for (const toolCall of step.step_details.tool_calls) {
          if (toolCall.type === 'function') {
            outputBlocks.push(
              `[Call function ${toolCall.function.name} with arguments ${toolCall.function.arguments}]`,
            );
            outputBlocks.push(`[Function output: ${toolCall.function.output}]`);
          } else if (toolCall.type === 'retrieval') {
            outputBlocks.push(`[Ran retrieval]`);
          } else if (toolCall.type === 'code_interpreter') {
            const output = toolCall.code_interpreter.outputs
              .map((output) => (output.type === 'logs' ? output.logs : `<${output.type} output>`))
              .join('\n');
            outputBlocks.push(`[Code interpreter input]`);
            outputBlocks.push(toolCall.code_interpreter.input);
            outputBlocks.push(`[Code interpreter output]`);
            outputBlocks.push(output);
          } else {
            outputBlocks.push(`[Unknown tool call type: ${(toolCall as any).type}]`);
          }
        }
      } else {
        outputBlocks.push(`[Unknown step type: ${(step.step_details as any).type}]`);
      }
    }

    return {
      output: outputBlocks.join('\n\n').trim(),
      /*
      tokenUsage: {
        total: data.usage.total_tokens,
        prompt: data.usage.prompt_tokens,
        completion: data.usage.completion_tokens,
      },
      */
    };
  }
}

type OpenAiImageOptions = OpenAiSharedOptions & {
  size?: string;
};

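// Provider for the /images/generations endpoint (e.g. DALL·E models).
// Returns the generated image as a markdown image link.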
export class OpenAiImageProvider extends OpenAiGenericProvider {
  config: OpenAiImageOptions;

  constructor(
    modelName: string,
    options: { config?: OpenAiImageOptions; id?: string; env?: EnvOverrides } = {},
  ) {
    super(modelName, options);
    this.config = options.config || {};
  }

  async callApi(
    prompt: string,
    context?: CallApiContextParams,
    callApiOptions?: CallApiOptionsParams,
  ): Promise<ProviderResponse> {
    const cache = getCache();
    const cacheKey = `openai:image:${JSON.stringify({ context, prompt })}`;

    if (!this.getApiKey()) {
      throw new Error(
        'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
      );
    }

    const openai = new OpenAI({
      apiKey: this.getApiKey(),
      organization: this.getOrganization(),
      // Unfortunate, but the OpenAI SDK's implementation of base URL is different from how we treat base URL elsewhere.
      baseURL: this.getApiUrl(),
      maxRetries: 3,
      timeout: REQUEST_TIMEOUT_MS,
    });

    let response: OpenAI.Images.ImagesResponse | undefined;
    let cached = false;
    if (isCacheEnabled()) {
      // Try to get the cached response
      const cachedResponse = await cache.get(cacheKey);
      if (cachedResponse) {
        logger.debug(`Retrieved cached response for ${prompt}: ${cachedResponse}`);
        response = JSON.parse(cachedResponse as string) as OpenAI.Images.ImagesResponse;
        cached = true;
      }
    }

    if (!response) {
      response = await openai.images.generate({
        model: this.modelName,
        prompt,
        n: 1,
        size:
          ((this.config.size || process.env.OPENAI_IMAGE_SIZE) as
            | '1024x1024'
            | '256x256'
            | '512x512'
            | '1792x1024'
            | '1024x1792'
            | undefined) || '1024x1024',
      });
    }

    const url = response.data[0].url;
    if (!url) {
      return {
        error: `No image URL found in response: ${JSON.stringify(response)}`,
      };
    }

    if (!cached && isCacheEnabled()) {
      try {
        await cache.set(cacheKey, JSON.stringify(response));
      } catch (err) {
        logger.error(`Failed to cache response: ${String(err)}`);
      }
    }

    const sanitizedPrompt = prompt
      .replace(/\r?\n|\r/g, ' ')
      .replace(/\[/g, '(')
      .replace(/\]/g, ')');
    const ellipsizedPrompt =
      sanitizedPrompt.length > 50 ? `${sanitizedPrompt.substring(0, 47)}...` : sanitizedPrompt;
    return {
      output: `![${ellipsizedPrompt}](${url})`,
      cached,
    };
  }
}

export const DefaultEmbeddingProvider = new OpenAiEmbeddingProvider('text-embedding-3-large');
export const DefaultGradingProvider = new OpenAiChatCompletionProvider('gpt-4-0125-preview');
export const DefaultGradingJsonProvider = new OpenAiChatCompletionProvider('gpt-4-0125-preview', {
  config: {
    response_format: { type: 'json_object' },
  },
});
export const DefaultSuggestionsProvider = new OpenAiChatCompletionProvider('gpt-4-0125-preview');

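/*
 * Example usage (illustrative sketch — promptfoo normally constructs these
 * providers from its evaluation config rather than directly):
 *
 *   const provider = new OpenAiChatCompletionProvider('gpt-3.5-turbo', {
 *     config: { temperature: 0, max_tokens: 256 },
 *   });
 *   const result = await provider.callApi('Say hello in French.');
 *   // => { output, tokenUsage, cached, logProbs, cost } on success,
 *   //    or { error } on failure.
 */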