1
import OpenAI from 'openai';
3
import logger from '../logger';
4
import { fetchWithCache, getCache, isCacheEnabled } from '../cache';
5
import { REQUEST_TIMEOUT_MS, parseChatPrompt, toTitleCase } from './shared';
6
import { OpenAiFunction, OpenAiTool } from './openaiUtil';
13
ProviderEmbeddingResponse,
18
interface OpenAiSharedOptions {
23
organization?: string;
27
type OpenAiCompletionOptions = OpenAiSharedOptions & {
31
frequency_penalty?: number;
32
presence_penalty?: number;
34
functions?: OpenAiFunction[];
35
function_call?: 'none' | 'auto' | { name: string };
37
tool_choice?: 'none' | 'auto' | { type: 'function'; function?: { name: string } };
38
response_format?: { type: 'json_object' };
44
function failApiCall(err: any) {
45
if (err instanceof OpenAI.APIError) {
47
error: `API error: ${err.type} ${err.message}`,
51
error: `API error: ${String(err)}`,
55
function getTokenUsage(data: any, cached: boolean): Partial<TokenUsage> {
58
return { cached: data.usage.total_tokens, total: data.usage.total_tokens };
61
total: data.usage.total_tokens,
62
prompt: data.usage.prompt_tokens || 0,
63
completion: data.usage.completion_tokens || 0,
70
export class OpenAiGenericProvider implements ApiProvider {
73
config: OpenAiSharedOptions;
78
options: { config?: OpenAiSharedOptions; id?: string; env?: EnvOverrides } = {},
80
const { config, id, env } = options;
82
this.modelName = modelName;
83
this.config = config || {};
84
this.id = id ? () => id : this.id;
88
return `openai:${this.modelName}`;
92
return `[OpenAI Provider ${this.modelName}]`;
95
getOrganization(): string | undefined {
97
this.config.organization || this.env?.OPENAI_ORGANIZATION || process.env.OPENAI_ORGANIZATION
101
getApiUrlDefault(): string {
102
return 'https://api.openai.com/v1';
105
getApiUrl(): string {
106
const apiHost = this.config.apiHost || this.env?.OPENAI_API_HOST || process.env.OPENAI_API_HOST;
108
return `https://${apiHost}/v1`;
111
this.config.apiBaseUrl ||
112
this.env?.OPENAI_API_BASE_URL ||
113
process.env.OPENAI_API_BASE_URL ||
114
this.getApiUrlDefault()
118
getApiKey(): string | undefined {
120
this.config.apiKey ||
121
(this.config?.apiKeyEnvar
122
? process.env[this.config.apiKeyEnvar] ||
123
this.env?.[this.config.apiKeyEnvar as keyof EnvOverrides]
125
this.env?.OPENAI_API_KEY ||
126
process.env.OPENAI_API_KEY
133
context?: CallApiContextParams,
134
callApiOptions?: CallApiOptionsParams,
135
): Promise<ProviderResponse> {
136
throw new Error('Not implemented');
140
export class OpenAiEmbeddingProvider extends OpenAiGenericProvider {
141
async callEmbeddingApi(text: string): Promise<ProviderEmbeddingResponse> {
142
if (!this.getApiKey()) {
143
throw new Error('OpenAI API key must be set for similarity comparison');
148
model: this.modelName,
153
({ data, cached } = (await fetchWithCache(
154
`${this.getApiUrl()}/embeddings`,
158
'Content-Type': 'application/json',
159
Authorization: `Bearer ${this.getApiKey()}`,
160
...(this.getOrganization() ? { 'OpenAI-Organization': this.getOrganization() } : {}),
162
body: JSON.stringify(body),
165
)) as unknown as any);
167
logger.error(`API call error: ${err}`);
170
logger.debug(`\tOpenAI embeddings API response: ${JSON.stringify(data)}`);
173
const embedding = data?.data?.[0]?.embedding;
175
throw new Error('No embedding found in OpenAI embeddings API response');
179
tokenUsage: getTokenUsage(data, cached),
182
logger.error(data.error.message);
188
// NOTE(review): this region is corrupted — original file line numbers are
// interleaved as bare lines (e.g. `189`) and many source lines are missing
// (the numbers are non-consecutive). Do NOT edit in place; restore from
// version control first. Comments below describe only the visible fragments.
//
// Provider for OpenAI's legacy /v1/completions endpoint. Holds a static table
// of known completion models (per-1K-token costs visible for the instruct
// models), warns on unknown models when targeting the default API URL, and a
// callApi that builds a request body from config with OPENAI_* env-var
// defaults, POSTs via fetchWithCache, and maps the response to an output plus
// token usage.
export class OpenAiCompletionProvider extends OpenAiGenericProvider {
189
static OPENAI_COMPLETION_MODELS = [
191
id: 'gpt-3.5-turbo-instruct',
193
input: 0.0015 / 1000,
194
output: 0.002 / 1000,
198
id: 'gpt-3.5-turbo-instruct-0914',
200
input: 0.0015 / 1000,
201
output: 0.002 / 1000,
205
id: 'text-davinci-003',
208
id: 'text-davinci-002',
211
id: 'text-curie-001',
214
id: 'text-babbage-001',
221
static OPENAI_COMPLETION_MODEL_NAMES = OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS.map(
225
config: OpenAiCompletionOptions;
229
options: { config?: OpenAiCompletionOptions; id?: string; env?: EnvOverrides } = {},
231
super(modelName, options);
232
this.config = options.config || {};
234
// Warn (don't fail) on unknown models, but only when pointing at the real
// OpenAI endpoint — custom base URLs may serve arbitrary model names.
!OpenAiCompletionProvider.OPENAI_COMPLETION_MODEL_NAMES.includes(modelName) &&
235
this.getApiUrl() === this.getApiUrlDefault()
237
logger.warn(`FYI: Using unknown OpenAI completion model: ${modelName}`);
243
context?: CallApiContextParams,
244
callApiOptions?: CallApiOptionsParams,
245
): Promise<ProviderResponse> {
246
if (!this.getApiKey()) {
248
'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
254
// OPENAI_STOP env var (JSON) overrides config.stop; defaults to ChatML
// sentinels for instruct-style models.
stop = process.env.OPENAI_STOP
255
? JSON.parse(process.env.OPENAI_STOP)
256
: this.config?.stop || ['<|im_end|>', '<|endoftext|>'];
258
throw new Error(`OPENAI_STOP is not a valid JSON string: ${err}`);
261
model: this.modelName,
263
seed: this.config.seed || 0,
264
max_tokens: this.config.max_tokens ?? parseInt(process.env.OPENAI_MAX_TOKENS || '1024'),
265
temperature: this.config.temperature ?? parseFloat(process.env.OPENAI_TEMPERATURE || '0'),
266
top_p: this.config.top_p ?? parseFloat(process.env.OPENAI_TOP_P || '1'),
268
this.config.presence_penalty ?? parseFloat(process.env.OPENAI_PRESENCE_PENALTY || '0'),
270
this.config.frequency_penalty ?? parseFloat(process.env.OPENAI_FREQUENCY_PENALTY || '0'),
271
best_of: this.config.best_of ?? parseInt(process.env.OPENAI_BEST_OF || '1'),
272
...(callApiOptions?.includeLogProbs ? { logprobs: callApiOptions.includeLogProbs } : {}),
273
...(stop ? { stop } : {}),
274
// passthrough is spread last so it can override any field above.
...(this.config.passthrough || {}),
276
logger.debug(`Calling OpenAI API: ${JSON.stringify(body)}`);
280
({ data, cached } = (await fetchWithCache(
281
`${this.getApiUrl()}/completions`,
285
'Content-Type': 'application/json',
286
Authorization: `Bearer ${this.getApiKey()}`,
287
...(this.getOrganization() ? { 'OpenAI-Organization': this.getOrganization() } : {}),
289
body: JSON.stringify(body),
292
)) as unknown as any);
295
error: `API call error: ${String(err)}`,
298
logger.debug(`\tOpenAI completions API response: ${JSON.stringify(data)}`);
301
error: formatOpenAiError(data),
306
output: data.choices[0].text,
307
tokenUsage: getTokenUsage(data, cached),
312
// presumably arguments to a cost calculation (calculateCost) — the
// enclosing call is among the missing lines; confirm against VCS.
data.usage?.prompt_tokens,
313
data.usage?.completion_tokens,
318
error: `API error: ${String(err)}: ${JSON.stringify(data)}`,
324
// NOTE(review): this region is corrupted — original file line numbers are
// interleaved as bare lines and many source lines are missing. Restore from
// version control before editing. Comments describe only visible fragments.
//
// Provider for OpenAI's /v1/chat/completions endpoint. Static table of known
// chat models (gpt-4 and gpt-3.5 families; cost fractions visible for some
// entries), a callApi that parses the prompt into chat messages, builds a
// request body (functions/tools/response_format supported), POSTs via
// fetchWithCache, and extracts message content or function/tool calls.
export class OpenAiChatCompletionProvider extends OpenAiGenericProvider {
325
static OPENAI_CHAT_MODELS = [
326
...['gpt-4', 'gpt-4-0314', 'gpt-4-0613'].map((model) => ({
334
'gpt-4-1106-preview',
335
'gpt-4-1106-vision-preview',
336
'gpt-4-0125-preview',
337
'gpt-4-turbo-preview',
345
...['gpt-4-32k', 'gpt-4-32k-0314'].map((model) => ({
354
'gpt-3.5-turbo-0301',
355
'gpt-3.5-turbo-0613',
356
'gpt-3.5-turbo-1106',
357
'gpt-3.5-turbo-0125',
359
'gpt-3.5-turbo-16k-0613',
363
input: 0.0005 / 1000,
364
output: 0.0015 / 1000,
369
static OPENAI_CHAT_MODEL_NAMES = OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS.map(
373
config: OpenAiCompletionOptions;
377
options: { config?: OpenAiCompletionOptions; id?: string; env?: EnvOverrides } = {},
379
// Unlike the completion provider, this warns regardless of API URL.
if (!OpenAiChatCompletionProvider.OPENAI_CHAT_MODEL_NAMES.includes(modelName)) {
380
logger.warn(`Using unknown OpenAI chat model: ${modelName}`);
382
super(modelName, options);
383
this.config = options.config || {};
388
context?: CallApiContextParams,
389
callApiOptions?: CallApiOptionsParams,
390
): Promise<ProviderResponse> {
391
if (!this.getApiKey()) {
393
'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
397
// Prompt may be a JSON chat transcript; otherwise wrap as single user turn.
const messages = parseChatPrompt(prompt, [{ role: 'user', content: prompt }]);
401
stop = process.env.OPENAI_STOP
402
? JSON.parse(process.env.OPENAI_STOP)
403
: this.config?.stop || [];
405
throw new Error(`OPENAI_STOP is not a valid JSON string: ${err}`);
408
model: this.modelName,
410
seed: this.config.seed || 0,
411
max_tokens: this.config.max_tokens ?? parseInt(process.env.OPENAI_MAX_TOKENS || '1024'),
412
temperature: this.config.temperature ?? parseFloat(process.env.OPENAI_TEMPERATURE || '0'),
413
top_p: this.config.top_p ?? parseFloat(process.env.OPENAI_TOP_P || '1'),
415
this.config.presence_penalty ?? parseFloat(process.env.OPENAI_PRESENCE_PENALTY || '0'),
417
this.config.frequency_penalty ?? parseFloat(process.env.OPENAI_FREQUENCY_PENALTY || '0'),
418
...(this.config.functions ? { functions: this.config.functions } : {}),
419
...(this.config.function_call ? { function_call: this.config.function_call } : {}),
420
...(this.config.tools ? { tools: this.config.tools } : {}),
421
...(this.config.tool_choice ? { tool_choice: this.config.tool_choice } : {}),
422
...(this.config.response_format ? { response_format: this.config.response_format } : {}),
423
...(callApiOptions?.includeLogProbs ? { logprobs: callApiOptions.includeLogProbs } : {}),
424
...(this.config.stop ? { stop: this.config.stop } : {}),
425
...(this.config.passthrough || {}),
427
logger.debug(`Calling OpenAI API: ${JSON.stringify(body)}`);
432
({ data, cached } = (await fetchWithCache(
433
`${this.getApiUrl()}/chat/completions`,
437
'Content-Type': 'application/json',
438
Authorization: `Bearer ${this.getApiKey()}`,
439
...(this.getOrganization() ? { 'OpenAI-Organization': this.getOrganization() } : {}),
441
body: JSON.stringify(body),
444
)) as unknown as { data: any; cached: boolean });
447
error: `API call error: ${String(err)}`,
451
logger.debug(`\tOpenAI chat completions API response: ${JSON.stringify(data)}`);
454
error: formatOpenAiError(data),
458
const message = data.choices[0].message;
460
// When content is null the model returned a function/tool call instead.
message.content === null ? message.function_call || message.tool_calls : message.content;
461
const logProbs = data.choices[0].logprobs?.content?.map(
462
(logProbObj: { token: string; logprob: number }) => logProbObj.logprob,
467
tokenUsage: getTokenUsage(data, cached),
473
data.usage?.prompt_tokens,
474
data.usage?.completion_tokens,
479
error: `API error: ${String(err)}: ${JSON.stringify(data)}`,
485
function formatOpenAiError(data: { error: { message: string; type?: string; code?: string } }) {
487
`API error: ${data.error.message}` +
488
(data.error.type ? `, Type: ${data.error.type}` : '') +
489
(data.error.code ? `, Code: ${data.error.code}` : '')
493
function calculateCost(
495
config: OpenAiSharedOptions,
496
promptTokens?: number,
497
completionTokens?: number,
498
): number | undefined {
499
if (!promptTokens || !completionTokens) {
504
...OpenAiChatCompletionProvider.OPENAI_CHAT_MODELS,
505
...OpenAiCompletionProvider.OPENAI_COMPLETION_MODELS,
506
].find((m) => m.id === modelName);
507
if (!model || !model.cost) {
511
const inputCost = config.cost ?? model.cost.input;
512
const outputCost = config.cost ?? model.cost.output;
513
return inputCost * promptTokens + outputCost * completionTokens || undefined;
516
// NOTE(review): this region is corrupted — original line numbers interleaved
// and fields missing from both interfaces. Restore from version control.
//
// Shapes for assistant (beta threads) message responses and provider options.
interface AssistantMessagesResponseDataContent {
523
interface AssistantMessagesResponseData {
526
content?: AssistantMessagesResponseDataContent[];
530
// Options for OpenAiAssistantProvider: model/instruction overrides, tool
// definitions passed to thread-run creation, and local callbacks invoked when
// a run requires function tool outputs (keyed by function name).
type OpenAiAssistantOptions = OpenAiSharedOptions & {
532
instructions?: string;
533
tools?: OpenAI.Beta.Threads.ThreadCreateAndRunParams['tools'];
538
functionToolCallbacks?: Record<
539
OpenAI.FunctionDefinition['name'],
540
(arg: string) => Promise<string>
545
// NOTE(review): this region is corrupted — original line numbers interleaved
// as bare lines and many lines (loop/try shells, return statements) missing.
// Restore from version control before editing. Comments describe only what
// the visible fragments show.
//
// Provider for the OpenAI Assistants (beta threads) API. callApi creates a
// thread run from the prompt, polls its status (1s sleep between retrievals),
// services `requires_action` by invoking configured functionToolCallbacks and
// submitting their outputs, then renders the run steps (messages, function
// calls, retrieval, code interpreter) into a single text output.
export class OpenAiAssistantProvider extends OpenAiGenericProvider {
547
assistantConfig: OpenAiAssistantOptions;
551
options: { config?: OpenAiAssistantOptions; id?: string; env?: EnvOverrides } = {},
553
super(assistantId, options);
554
this.assistantConfig = options.config || {};
555
this.assistantId = assistantId;
558
async callApi(prompt: string): Promise<ProviderResponse> {
559
if (!this.getApiKey()) {
561
'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
565
// Direct SDK client (no fetchWithCache) — assistant runs are stateful.
const openai = new OpenAI({
566
apiKey: this.getApiKey(),
567
organization: this.getOrganization(),
569
baseURL: this.getApiUrl(),
571
timeout: REQUEST_TIMEOUT_MS,
574
const messages = parseChatPrompt(prompt, [
575
{ role: 'user', content: prompt },
576
]) as OpenAI.Beta.Threads.ThreadCreateParams.Message[];
577
const body: OpenAI.Beta.Threads.ThreadCreateAndRunParams = {
578
assistant_id: this.assistantId,
579
model: this.assistantConfig.modelName || undefined,
580
instructions: this.assistantConfig.instructions || undefined,
581
tools: this.assistantConfig.tools || undefined,
582
metadata: this.assistantConfig.metadata || undefined,
588
logger.debug(`Calling OpenAI API, creating thread run: ${JSON.stringify(body)}`);
591
run = await openai.beta.threads.createAndRun(body);
593
return failApiCall(err);
596
logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`);
599
// Poll while the run is not in a terminal state.
run.status === 'in_progress' ||
600
run.status === 'queued' ||
601
run.status === 'requires_action'
603
if (run.status === 'requires_action') {
604
const requiredAction = run.required_action;
605
if (requiredAction === null || requiredAction.type !== 'submit_tool_outputs') {
608
// Only service function calls for which a local callback is configured.
const functionCallsWithCallbacks = requiredAction.submit_tool_outputs.tool_calls.filter(
611
toolCall.type === 'function' &&
612
toolCall.function.name in (this.assistantConfig.functionToolCallbacks ?? {})
616
if (functionCallsWithCallbacks.length === 0) {
620
`Calling functionToolCallbacks for functions: ${functionCallsWithCallbacks.map(
621
({ function: { name } }) => name,
624
const toolOutputs = await Promise.all(
625
functionCallsWithCallbacks.map(async (toolCall) => {
627
`Calling functionToolCallbacks[${toolCall.function.name}]('${toolCall.function.arguments}')`,
629
const result = await this.assistantConfig.functionToolCallbacks![
630
toolCall.function.name
631
](toolCall.function.arguments);
633
tool_call_id: toolCall.id,
639
`Calling OpenAI API, submitting tool outputs for ${run.thread_id}: ${JSON.stringify(
644
run = await openai.beta.threads.runs.submitToolOutputs(run.thread_id, run.id, {
645
tool_outputs: toolOutputs,
648
return failApiCall(err);
653
// Fixed 1-second delay between status polls.
await new Promise((resolve) => setTimeout(resolve, 1000));
655
logger.debug(`Calling OpenAI API, getting thread run ${run.id} status`);
657
run = await openai.beta.threads.runs.retrieve(run.thread_id, run.id);
659
return failApiCall(err);
661
logger.debug(`\tOpenAI thread run API response: ${JSON.stringify(run)}`);
664
if (run.status !== 'completed' && run.status !== 'requires_action') {
665
if (run.last_error) {
667
error: `Thread run failed: ${run.last_error.message}`,
671
error: `Thread run failed: ${run.status}`,
676
logger.debug(`Calling OpenAI API, getting thread run steps for ${run.thread_id}`);
679
steps = await openai.beta.threads.runs.steps.list(run.thread_id, run.id, {
683
return failApiCall(err);
685
logger.debug(`\tOpenAI thread run steps API response: ${JSON.stringify(steps)}`);
687
// Render each run step into a labeled text block.
const outputBlocks = [];
688
for (const step of steps.data) {
689
if (step.step_details.type === 'message_creation') {
690
logger.debug(`Calling OpenAI API, getting message ${step.id}`);
693
message = await openai.beta.threads.messages.retrieve(
695
step.step_details.message_creation.message_id,
698
return failApiCall(err);
700
logger.debug(`\tOpenAI thread run step message API response: ${JSON.stringify(message)}`);
702
const content = message.content
704
content.type === 'text' ? content.text.value : `<${content.type} output>`,
707
outputBlocks.push(`[${toTitleCase(message.role)}] ${content}`);
708
} else if (step.step_details.type === 'tool_calls') {
709
for (const toolCall of step.step_details.tool_calls) {
710
if (toolCall.type === 'function') {
712
`[Call function ${toolCall.function.name} with arguments ${toolCall.function.arguments}]`,
714
outputBlocks.push(`[Function output: ${toolCall.function.output}]`);
715
} else if (toolCall.type === 'retrieval') {
716
outputBlocks.push(`[Ran retrieval]`);
717
} else if (toolCall.type === 'code_interpreter') {
718
const output = toolCall.code_interpreter.outputs
719
.map((output) => (output.type === 'logs' ? output.logs : `<${output.type} output>`))
721
outputBlocks.push(`[Code interpreter input]`);
722
outputBlocks.push(toolCall.code_interpreter.input);
723
outputBlocks.push(`[Code interpreter output]`);
724
outputBlocks.push(output);
726
outputBlocks.push(`[Unknown tool call type: ${(toolCall as any).type}]`);
730
outputBlocks.push(`[Unknown step type: ${(step.step_details as any).type}]`);
735
output: outputBlocks.join('\n\n').trim(),
747
// NOTE(review): this region is corrupted — original line numbers interleaved
// and many lines missing (the OpenAiImageOptions body, try/catch shells, the
// size union type). Restore from version control before editing.
//
// Provider for OpenAI's image-generation endpoint. callApi checks a local
// cache keyed on (context, prompt), otherwise calls openai.images.generate,
// caches the JSON response, and returns the image as a markdown image link
// with a sanitized, ellipsized prompt as alt text.
type OpenAiImageOptions = OpenAiSharedOptions & {
751
export class OpenAiImageProvider extends OpenAiGenericProvider {
752
config: OpenAiImageOptions;
756
options: { config?: OpenAiImageOptions; id?: string; env?: EnvOverrides } = {},
758
super(modelName, options);
759
this.config = options.config || {};
764
context?: CallApiContextParams,
765
callApiOptions?: CallApiOptionsParams,
766
): Promise<ProviderResponse> {
767
const cache = getCache();
768
const cacheKey = `openai:image:${JSON.stringify({ context, prompt })}`;
770
if (!this.getApiKey()) {
772
'OpenAI API key is not set. Set the OPENAI_API_KEY environment variable or add `apiKey` to the provider config.',
776
const openai = new OpenAI({
777
apiKey: this.getApiKey(),
778
organization: this.getOrganization(),
780
baseURL: this.getApiUrl(),
782
timeout: REQUEST_TIMEOUT_MS,
785
let response: OpenAI.Images.ImagesResponse | undefined;
787
if (isCacheEnabled()) {
789
const cachedResponse = await cache.get(cacheKey);
790
if (cachedResponse) {
791
logger.debug(`Retrieved cached response for ${prompt}: ${cachedResponse}`);
792
response = JSON.parse(cachedResponse as string) as OpenAI.Images.ImagesResponse;
798
response = await openai.images.generate({
799
model: this.modelName,
803
// size comes from config, then OPENAI_IMAGE_SIZE, then '1024x1024'.
((this.config.size || process.env.OPENAI_IMAGE_SIZE) as
809
| undefined) || '1024x1024',
813
const url = response.data[0].url;
816
error: `No image URL found in response: ${JSON.stringify(response)}`,
820
if (!cached && isCacheEnabled()) {
822
await cache.set(cacheKey, JSON.stringify(response));
824
// Best-effort cache write: log the failure, don't fail the call.
logger.error(`Failed to cache response: ${String(err)}`);
828
// Sanitize the prompt for use inside markdown link syntax.
const sanitizedPrompt = prompt
829
.replace(/\r?\n|\r/g, ' ')
831
.replace(/\]/g, ')');
832
const ellipsizedPrompt =
833
sanitizedPrompt.length > 50 ? `${sanitizedPrompt.substring(0, 47)}...` : sanitizedPrompt;
835
output: `![${ellipsizedPrompt}](${url})`,
841
export const DefaultEmbeddingProvider = new OpenAiEmbeddingProvider('text-embedding-3-large');
842
export const DefaultGradingProvider = new OpenAiChatCompletionProvider('gpt-4-0125-preview');
843
export const DefaultGradingJsonProvider = new OpenAiChatCompletionProvider('gpt-4-0125-preview', {
845
response_format: { type: 'json_object' },
848
export const DefaultSuggestionsProvider = new OpenAiChatCompletionProvider('gpt-4-0125-preview');