promptfoo

Форк
0
/
matchers.ts 
822 строки · 25.6 Кб
1
import invariant from 'tiny-invariant';
2
import logger from './logger';
3
import {
4
  DefaultEmbeddingProvider,
5
  DefaultGradingJsonProvider,
6
  DefaultGradingProvider,
7
} from './providers/openai';
8
import { getNunjucksEngine } from './util';
9
import { loadApiProvider } from './providers';
10
import {
11
  ANSWER_RELEVANCY_GENERATE,
12
  SELECT_BEST_PROMPT,
13
  CONTEXT_FAITHFULNESS_LONGFORM,
14
  CONTEXT_FAITHFULNESS_NLI_STATEMENTS,
15
  CONTEXT_RECALL,
16
  CONTEXT_RECALL_ATTRIBUTED_TOKEN,
17
  CONTEXT_RELEVANCE,
18
  CONTEXT_RELEVANCE_BAD,
19
  DEFAULT_GRADING_PROMPT,
20
  OPENAI_CLOSED_QA_PROMPT,
21
  OPENAI_FACTUALITY_PROMPT,
22
} from './prompts';
23

24
import type {
25
  ApiClassificationProvider,
26
  ApiEmbeddingProvider,
27
  ApiProvider,
28
  ApiSimilarityProvider,
29
  GradingConfig,
30
  GradingResult,
31
  ProviderOptions,
32
  ProviderTypeMap,
33
  TokenUsage,
34
} from './types';
35

36
// Module-wide Nunjucks engine; the matchers in this file call `nunjucks.renderString`
// on it to fill grading prompt templates.
const nunjucks = getNunjucksEngine();
37

38
/**
 * Computes the cosine similarity of two numeric vectors.
 *
 * @param vecA First vector.
 * @param vecB Second vector; must have the same length as `vecA`.
 * @returns The dot product divided by the product of the two magnitudes, or 0
 *   when either vector has zero magnitude (including two empty vectors), where
 *   the ratio would otherwise be NaN.
 * @throws Error when the vectors differ in length.
 */
function cosineSimilarity(vecA: number[], vecB: number[]) {
  if (vecA.length !== vecB.length) {
    throw new Error('Vectors must be of equal length');
  }
  const dotProduct = vecA.reduce((acc, val, idx) => acc + val * vecB[idx], 0);
  const vecAMagnitude = Math.sqrt(vecA.reduce((acc, val) => acc + val * val, 0));
  const vecBMagnitude = Math.sqrt(vecB.reduce((acc, val) => acc + val * val, 0));
  const denominator = vecAMagnitude * vecBMagnitude;
  // A zero-length direction has no angle; report "no similarity" instead of NaN
  // so callers that average or compare the value keep working.
  if (denominator === 0) {
    return 0;
  }
  return dotProduct / denominator;
}
47

48
/**
 * Flattens template variables into a string-only map.
 *
 * Values whose `typeof` is 'object' are serialized with `JSON.stringify`;
 * all other values are copied through unchanged.
 *
 * @param vars Optional variables keyed by name.
 * @returns A new map of the same keys to string values; empty when `vars` is absent.
 */
function fromVars(vars?: Record<string, string | object>) {
  const rendered: Record<string, string> = {};
  for (const [name, raw] of Object.entries(vars ?? {})) {
    rendered[name] = typeof raw === 'object' ? JSON.stringify(raw) : raw;
  }
  return rendered;
}
64

65
/**
 * Validates a provider options object and loads the provider it names.
 *
 * @param provider Provider options; must be a non-array object with a truthy `id`.
 * @returns Whatever `loadApiProvider` resolves to for `provider.id`, with the
 *   full options object passed along as `options`.
 * @throws Via `invariant` when `provider` is not an object, is an array, or has no id.
 */
async function loadFromProviderOptions(provider: ProviderOptions) {
  const isObject = typeof provider === 'object';
  invariant(
    isObject,
    `Provider must be an object, but received a ${typeof provider}: ${provider}`,
  );

  const isArray = Array.isArray(provider);
  invariant(
    !isArray,
    `Provider must be an object, but received an array: ${JSON.stringify(provider)}`,
  );

  const { id } = provider;
  invariant(id, 'Provider supplied to assertion must have an id');

  // TODO(ian): set basepath if invoked from filesystem config
  return loadApiProvider(id, { options: provider });
}
78

79
/**
 * Resolves the grading provider to use for a given output type.
 *
 * The `provider` value is interpreted in this order:
 *  1. a string: loaded through `loadApiProvider`;
 *  2. an object whose `id` is a function: used as-is as an `ApiProvider`;
 *  3. an object with an entry under `type`: that entry is resolved recursively;
 *  4. an object with a truthy `id`: loaded through `loadFromProviderOptions`;
 *  5. any other object: rejected with an error;
 *  6. anything else (e.g. undefined): `defaultProvider` is returned.
 *
 * @param type Kind of provider required.
 * @param provider Provider definition from the grading config.
 * @param defaultProvider Fallback used when no definition is supplied; may be null.
 * @returns The resolved provider, or null when falling back to a null default.
 * @throws Error when an object definition matches none of the supported shapes.
 */
export async function getGradingProvider(
  type: 'embedding' | 'classification' | 'text',
  provider: GradingConfig['provider'],
  defaultProvider: ApiProvider | null,
): Promise<ApiProvider | null> {
  // Defined as a string
  if (typeof provider === 'string') {
    return loadApiProvider(provider);
  }

  // Nothing usable was supplied
  if (typeof provider !== 'object') {
    return defaultProvider;
  }

  // Defined as an ApiProvider interface
  if (typeof (provider as ApiProvider).id === 'function') {
    return provider as ApiProvider;
  }

  // Defined as embedding, classification, or text record
  const nestedDefinition = (provider as ProviderTypeMap)[type];
  if (nestedDefinition) {
    return getGradingProvider(type, nestedDefinition, defaultProvider);
  }

  // Defined as ProviderOptions
  if ((provider as ProviderOptions).id) {
    return loadFromProviderOptions(provider as ProviderOptions);
  }

  const serialized = JSON.stringify(provider, null, 2);
  throw new Error(`Invalid provider definition for output type '${type}': ${serialized}`);
}
113

114
/**
 * Resolves a grading provider and verifies that it supports the requested type.
 *
 * An 'embedding' provider must expose `callEmbeddingApi` or `callSimilarityApi`;
 * a 'classification' provider must expose `callClassificationApi`; 'text'
 * providers are not capability-checked. When resolution yields nothing or the
 * capability check fails, the default provider is returned with a warning if
 * one exists, otherwise an error is thrown.
 *
 * @param type Kind of provider required.
 * @param provider Provider definition from the grading config.
 * @param defaultProvider Fallback provider, or null when there is none.
 * @param checkName Name of the check, used in warning and error messages.
 * @returns A provider for the requested type.
 * @throws Error when no suitable provider is found and there is no default.
 */
export async function getAndCheckProvider(
  type: 'embedding' | 'classification' | 'text',
  provider: GradingConfig['provider'],
  defaultProvider: ApiProvider | null,
  checkName: string,
): Promise<ApiProvider> {
  const candidate = await getGradingProvider(type, provider, defaultProvider);

  if (!candidate) {
    if (!defaultProvider) {
      throw new Error(`No provider of type ${type} found for '${checkName}'`);
    }
    logger.warn(`No provider of type ${type} found for '${checkName}', falling back to default`);
    return defaultProvider;
  }

  let supportsType: boolean;
  switch (type) {
    case 'embedding':
      supportsType = 'callEmbeddingApi' in candidate || 'callSimilarityApi' in candidate;
      break;
    case 'classification':
      supportsType = 'callClassificationApi' in candidate;
      break;
    default:
      supportsType = true;
  }

  if (supportsType) {
    return candidate;
  }

  if (!defaultProvider) {
    throw new Error(
      `Provider ${candidate.id()} is not a valid ${type} provider for '${checkName}'`,
    );
  }
  logger.warn(
    `Provider ${candidate.id()} is not a valid ${type} provider for '${checkName}', falling back to default`,
  );
  return defaultProvider;
}
153

154
/**
 * Builds a failing grading result.
 *
 * @param reason Human-readable explanation of the failure.
 * @param tokensUsed Optional token usage; missing or falsy counters become 0.
 * @returns A result with `pass: false`, `score: 0`, the reason, and normalized token counts.
 */
function fail(reason: string, tokensUsed?: Partial<TokenUsage>): Omit<GradingResult, 'assertion'> {
  const usage: Partial<TokenUsage> = tokensUsed || {};
  const normalizedUsage = {
    total: usage.total || 0,
    prompt: usage.prompt || 0,
    completion: usage.completion || 0,
  };
  return { pass: false, score: 0, reason, tokensUsed: normalizedUsage };
}
166

167
/**
 * Grades `output` by its semantic similarity to `expected`.
 *
 * Uses the provider's `callSimilarityApi` when available; otherwise embeds both
 * strings with `callEmbeddingApi` and compares them with cosine similarity.
 *
 * @param expected Reference text.
 * @param output Text under test.
 * @param threshold Similarity threshold.
 * @param inverse When true the check passes if similarity is at or below the
 *   threshold, and the reported score is `1 - similarity`.
 * @param grading Optional grading config supplying the provider.
 * @returns Grading result with the score, a reason, and token usage.
 * @throws Error when the provider implements neither similarity nor embedding calls.
 */
export async function matchesSimilarity(
  expected: string,
  output: string,
  threshold: number,
  inverse: boolean = false,
  grading?: GradingConfig,
): Promise<Omit<GradingResult, 'assertion'>> {
  const finalProvider = (await getAndCheckProvider(
    'embedding',
    grading?.provider,
    DefaultEmbeddingProvider,
    'similarity check',
  )) as ApiEmbeddingProvider | ApiSimilarityProvider;

  let similarity: number;
  let tokensUsed: TokenUsage;

  if ('callSimilarityApi' in finalProvider) {
    const direct = await finalProvider.callSimilarityApi(expected, output);
    tokensUsed = {
      total: 0,
      prompt: 0,
      completion: 0,
      ...direct.tokenUsage,
    };
    if (direct.error) {
      return fail(direct.error, tokensUsed);
    }
    if (direct.similarity == null) {
      return fail('Unknown error fetching similarity', tokensUsed);
    }
    similarity = direct.similarity;
  } else if ('callEmbeddingApi' in finalProvider) {
    const expectedResp = await finalProvider.callEmbeddingApi(expected);
    const outputResp = await finalProvider.callEmbeddingApi(output);
    const expectedUsage = expectedResp.tokenUsage;
    const outputUsage = outputResp.tokenUsage;

    tokensUsed = {
      total: (expectedUsage?.total || 0) + (outputUsage?.total || 0),
      prompt: (expectedUsage?.prompt || 0) + (outputUsage?.prompt || 0),
      completion: (expectedUsage?.completion || 0) + (outputUsage?.completion || 0),
    };

    const embeddingError = expectedResp.error || outputResp.error;
    if (embeddingError) {
      return fail(embeddingError || 'Unknown error fetching embeddings', tokensUsed);
    }
    if (!expectedResp.embedding || !outputResp.embedding) {
      return fail('Embedding not found', tokensUsed);
    }

    similarity = cosineSimilarity(expectedResp.embedding, outputResp.embedding);
  } else {
    throw new Error('Provider must implement callSimilarityApi or callEmbeddingApi');
  }

  const rounded = similarity.toFixed(2);
  const greaterThanReason = `Similarity ${rounded} is greater than threshold ${threshold}`;
  const lessThanReason = `Similarity ${rounded} is less than threshold ${threshold}`;

  const pass = inverse ? similarity <= threshold : similarity >= threshold;
  return {
    pass,
    score: inverse ? 1 - similarity : similarity,
    // Passing an inverse check, or failing a regular one, both mean "below threshold".
    reason: pass === inverse ? lessThanReason : greaterThanReason,
    tokensUsed,
  };
}
249

250
/**
 * Grades `output` with a classification provider.
 *
 * @param expected Expected classification. If undefined, matches any classification.
 * @param output Text to classify.
 * @param threshold Value between 0 and 1. If the expected classification is undefined, the
 *   threshold is the minimum score for any classification. If the expected classification is
 *   defined, the threshold is the minimum score for that classification.
 * @param grading Optional grading config supplying the classification provider. No default
 *   provider is passed for this check, so a usable one must be configured.
 * @returns Pass if the output matches the classification with a score greater than or equal
 *   to the threshold. When the provider returns no classification the result is a failure
 *   carrying the provider's error message.
 */
export async function matchesClassification(
  expected: string | undefined,
  output: string,
  threshold: number,
  grading?: GradingConfig,
): Promise<Omit<GradingResult, 'assertion'>> {
  const finalProvider = (await getAndCheckProvider(
    'classification',
    grading?.provider,
    null,
    'classification check',
  )) as ApiClassificationProvider;

  const resp = await finalProvider.callClassificationApi(output);

  if (!resp.classification) {
    return fail(resp.error || 'Unknown error fetching classification');
  }

  let score: number;
  if (expected === undefined) {
    const scores = Object.values(resp.classification);
    // Math.max() of an empty list is -Infinity; treat "no classes returned" as score 0.
    score = scores.length > 0 ? Math.max(...scores) : 0;
  } else {
    score = resp.classification[expected] || 0;
  }

  const pass = score >= threshold;
  const comparison = pass ? '>=' : '<';
  // Use the "maximum score" wording whenever no class was requested, so a failure
  // no longer reads "Classification undefined has score ...".
  const reason =
    expected === undefined
      ? `Maximum classification score ${score.toFixed(2)} ${comparison} ${threshold}`
      : `Classification ${expected} has score ${score.toFixed(2)} ${comparison} ${threshold}`;

  return {
    pass,
    score,
    reason,
  };
}
300

301
/**
 * Grades `output` against a free-form rubric using an LLM.
 *
 * The rubric prompt (custom or default) is rendered with the escaped output and
 * rubric plus any extra variables, sent to the text provider, and the span from
 * the first '{' to the last '}' of the reply is parsed as a JSON grading result.
 *
 * @param expected Rubric text.
 * @param output Text under test.
 * @param grading Grading config; required.
 * @param vars Extra template variables; may override `output` and `rubric`.
 * @returns The parsed grading result, or a failure when the provider errors or
 *   the reply cannot be parsed.
 * @throws Error when `grading` is missing.
 */
export async function matchesLlmRubric(
  expected: string,
  output: string,
  grading?: GradingConfig,
  vars?: Record<string, string | object>,
): Promise<Omit<GradingResult, 'assertion'>> {
  if (!grading) {
    throw new Error(
      'Cannot grade output without grading config. Specify --grader option or grading config.',
    );
  }

  const template = grading.rubricPrompt || DEFAULT_GRADING_PROMPT;
  const prompt = nunjucks.renderString(template, {
    // JSON-escape the text and drop the surrounding quotes.
    output: JSON.stringify(output).slice(1, -1),
    rubric: JSON.stringify(expected).slice(1, -1),
    ...fromVars(vars),
  });

  const grader = await getAndCheckProvider(
    'text',
    grading.provider,
    DefaultGradingJsonProvider,
    'llm-rubric check',
  );
  const resp = await grader.callApi(prompt);
  if (resp.error || !resp.output) {
    return fail(resp.error || 'No output', resp.tokenUsage);
  }

  invariant(typeof resp.output === 'string', 'llm-rubric produced malformed response');
  try {
    // Tolerate prose around the JSON object by cutting from the first '{' to the last '}'.
    const start = resp.output.indexOf('{');
    const end = resp.output.lastIndexOf('}');
    const parsed = JSON.parse(resp.output.substring(start, end + 1)) as Partial<GradingResult>;

    // An explicit `pass` wins; otherwise a positive score (or no score at all) passes.
    const pass = parsed.pass ?? (typeof parsed.score === 'undefined' ? true : parsed.score > 0);
    const usage = resp.tokenUsage;
    return {
      pass,
      score: parsed.score ?? (pass ? 1.0 : 0.0),
      reason: parsed.reason || (pass ? 'Grading passed' : 'Grading failed'),
      tokensUsed: {
        total: usage?.total || 0,
        prompt: usage?.prompt || 0,
        completion: usage?.completion || 0,
      },
    };
  } catch (err) {
    return fail(`llm-rubric produced malformed response: ${resp.output}`, resp.tokenUsage);
  }
}
351

352
/**
 * Grades the factual agreement between `output` and `expected` using an LLM
 * that answers with one of the options (A)-(E).
 *
 * Each option maps to a score taken from `grading.factuality` (defaults: 1 for
 * A, B, C and E; 0 for D). The check passes when the chosen option's score is
 * greater than 0.
 *
 * @param input Original question or prompt.
 * @param expected Reference ("ideal") answer.
 * @param output Answer under test.
 * @param grading Grading config; required.
 * @param vars Extra template variables; may override the built-in ones.
 * @returns Grading result, or a failure when the provider errors or the reply
 *   does not contain a recognizable option.
 * @throws Error when `grading` is missing.
 */
export async function matchesFactuality(
  input: string,
  expected: string,
  output: string,
  grading?: GradingConfig,
  vars?: Record<string, string | object>,
): Promise<Omit<GradingResult, 'assertion'>> {
  if (!grading) {
    throw new Error(
      'Cannot grade output without grading config. Specify --grader option or grading config.',
    );
  }

  const template = grading.rubricPrompt || OPENAI_FACTUALITY_PROMPT;
  const prompt = nunjucks.renderString(template, {
    input: JSON.stringify(input).slice(1, -1),
    ideal: JSON.stringify(expected).slice(1, -1),
    completion: JSON.stringify(output).slice(1, -1),
    ...fromVars(vars),
  });

  const grader = await getAndCheckProvider(
    'text',
    grading.provider,
    DefaultGradingProvider,
    'factuality check',
  );
  const resp = await grader.callApi(prompt);
  if (resp.error || !resp.output) {
    return fail(resp.error || 'No output', resp.tokenUsage);
  }

  invariant(typeof resp.output === 'string', 'factuality produced malformed response');
  try {
    const graderOutput = resp.output;
    // The preferred output starts like "(A)...", but we also support leading whitespace, lowercase letters, and omitting the first parenthesis.
    const answerMatch = graderOutput.match(/\s*\(?([a-eA-E])\)/);
    if (!answerMatch) {
      return fail(
        `Factuality checker output did not match expected format: ${graderOutput}`,
        resp.tokenUsage,
      );
    }
    const option = answerMatch[1].toUpperCase();

    const weights = grading.factuality;
    const scoreLookup: Record<string, number> = {
      A: weights?.subset ?? 1,
      B: weights?.superset ?? 1,
      C: weights?.agree ?? 1,
      D: weights?.disagree ?? 0,
      E: weights?.differButFactual ?? 1,
    };
    const optionReasons: Record<string, string> = {
      A: `The submitted answer is a subset of the expert answer and is fully consistent with it.`,
      B: `The submitted answer is a superset of the expert answer and is fully consistent with it.`,
      C: `The submitted answer contains all the same details as the expert answer.`,
      D: `There is a disagreement between the submitted answer and the expert answer.`,
      E: `The answers differ, but these differences don't matter from the perspective of factuality.`,
    };

    const optionScore = scoreLookup[option];
    const knownReason = optionReasons[option];

    // Passing is defined as scores with value >0, and failing as scores with value 0.
    let pass = optionScore > 0;
    let reason: string;
    if (knownReason) {
      reason = knownReason;
    } else {
      pass = false;
      reason = `Invalid option: ${option}. Full response from factuality checker: ${resp.output}`;
    }

    // Prefer the configured score; fall back to 1/0 only when the option has none.
    const score = typeof optionScore !== 'undefined' ? optionScore : pass ? 1 : 0;

    const usage = resp.tokenUsage;
    return {
      pass,
      score,
      reason,
      tokensUsed: {
        total: usage?.total || 0,
        prompt: usage?.prompt || 0,
        completion: usage?.completion || 0,
      },
    };
  } catch (err) {
    return fail(`Error parsing output: ${(err as Error).message}`, resp.tokenUsage);
  }
}
444

445
/**
 * Grades `output` against a criterion using a closed-QA LLM prompt whose reply
 * is expected to end with 'Y' (criterion met) or 'N' (not met).
 *
 * Trailing whitespace in the reply is ignored when reading the verdict, so a
 * reply such as "...Y\n" is still recognized.
 *
 * @param input Original question or prompt.
 * @param expected Criterion the submission must satisfy.
 * @param output Submission under test.
 * @param grading Grading config; required.
 * @param vars Extra template variables; may override the built-in ones.
 * @returns Pass with score 1 when the reply ends with 'Y'; otherwise a failing
 *   result whose reason includes the full reply.
 * @throws Error when `grading` is missing.
 */
export async function matchesClosedQa(
  input: string,
  expected: string,
  output: string,
  grading?: GradingConfig,
  vars?: Record<string, string | object>,
): Promise<Omit<GradingResult, 'assertion'>> {
  if (!grading) {
    throw new Error(
      'Cannot grade output without grading config. Specify --grader option or grading config.',
    );
  }

  const prompt = nunjucks.renderString(grading.rubricPrompt || OPENAI_CLOSED_QA_PROMPT, {
    input: JSON.stringify(input).slice(1, -1),
    criteria: JSON.stringify(expected).slice(1, -1),
    completion: JSON.stringify(output).slice(1, -1),
    ...fromVars(vars),
  });

  const finalProvider = await getAndCheckProvider(
    'text',
    grading.provider,
    DefaultGradingProvider,
    'model-graded-closedqa check',
  );
  const resp = await finalProvider.callApi(prompt);
  if (resp.error || !resp.output) {
    return fail(resp.error || 'No output', resp.tokenUsage);
  }

  invariant(typeof resp.output === 'string', 'model-graded-closedqa produced malformed response');
  try {
    // Models frequently append a newline or spaces after the verdict letter.
    const verdictText = resp.output.trimEnd();
    const pass = verdictText.endsWith('Y');
    let reason;
    if (pass) {
      reason = 'The submission meets the criterion';
    } else if (verdictText.endsWith('N')) {
      reason = `The submission does not meet the criterion:\n${resp.output}`;
    } else {
      reason = `Model grader produced a malformed response:\n${resp.output}`;
    }
    return {
      pass,
      score: pass ? 1 : 0,
      reason,
      tokensUsed: {
        total: resp.tokenUsage?.total || 0,
        prompt: resp.tokenUsage?.prompt || 0,
        completion: resp.tokenUsage?.completion || 0,
      },
    };
  } catch (err) {
    return fail(`Error parsing output: ${(err as Error).message}`, resp.tokenUsage);
  }
}
501

502
/**
 * Grades how relevant `output` is to `input`.
 *
 * Asks the text provider three times to generate a question from `output`,
 * embeds `input` and each generated question, and averages the cosine
 * similarities between the input embedding and each question embedding.
 *
 * Token usage from every provider call (successful or not) is accumulated
 * into the returned `tokensUsed`.
 *
 * @param input Original question or prompt.
 * @param output Answer under test.
 * @param threshold Minimum average similarity required to pass.
 * @param grading Optional grading config supplying the providers.
 * @returns Grading result whose score is the average similarity, or a failure
 *   when any provider call errors or returns nothing.
 * @throws Via `invariant` when the embedding provider lacks `callEmbeddingApi`
 *   or the text provider returns a non-string output.
 */
export async function matchesAnswerRelevance(
  input: string,
  output: string,
  threshold: number,
  grading?: GradingConfig,
): Promise<Omit<GradingResult, 'assertion'>> {
  const embeddingProvider = await getAndCheckProvider(
    'embedding',
    grading?.provider,
    DefaultEmbeddingProvider,
    'answer relevancy check',
  );
  const textProvider = await getAndCheckProvider(
    'text',
    grading?.provider,
    DefaultGradingProvider,
    'answer relevancy check',
  );

  // Verify the embedding capability before spending any tokens on generation.
  invariant(
    typeof embeddingProvider.callEmbeddingApi === 'function',
    `Provider ${embeddingProvider.id()} must implement callEmbeddingApi for similarity check`,
  );

  const tokensUsed = {
    total: 0,
    prompt: 0,
    completion: 0,
  };
  const addTokenUsage = (usage?: { total?: number; prompt?: number; completion?: number }) => {
    tokensUsed.total += usage?.total || 0;
    tokensUsed.prompt += usage?.prompt || 0;
    tokensUsed.completion += usage?.completion || 0;
  };

  const candidateQuestions: string[] = [];
  for (let i = 0; i < 3; i++) {
    // TODO(ian): Parallelize
    const resp = await textProvider.callApi(
      JSON.stringify([
        ANSWER_RELEVANCY_GENERATE,
        {
          role: 'user',
          content: output,
        },
      ]),
    );
    // Count usage for every call, not only the failing ones.
    addTokenUsage(resp.tokenUsage);
    if (resp.error || !resp.output) {
      return fail(resp.error || 'No output', tokensUsed);
    }

    invariant(
      typeof resp.output === 'string',
      'answer relevancy check produced malformed response',
    );
    candidateQuestions.push(resp.output);
  }

  const inputEmbeddingResp = await embeddingProvider.callEmbeddingApi(input);
  addTokenUsage(inputEmbeddingResp.tokenUsage);
  if (inputEmbeddingResp.error || !inputEmbeddingResp.embedding) {
    return fail(inputEmbeddingResp.error || 'No embedding', tokensUsed);
  }
  const inputEmbedding = inputEmbeddingResp.embedding;

  const similarities: number[] = [];
  for (const question of candidateQuestions) {
    const resp = await embeddingProvider.callEmbeddingApi(question);
    addTokenUsage(resp.tokenUsage);
    if (resp.error || !resp.embedding) {
      return fail(resp.error || 'No embedding', tokensUsed);
    }
    similarities.push(cosineSimilarity(inputEmbedding, resp.embedding));
  }

  const similarity = similarities.reduce((a, b) => a + b, 0) / similarities.length;
  const pass = similarity >= threshold;
  const greaterThanReason = `Relevance ${similarity.toFixed(
    2,
  )} is greater than threshold ${threshold}`;
  const lessThanReason = `Relevance ${similarity.toFixed(2)} is less than threshold ${threshold}`;
  return {
    pass,
    score: similarity,
    reason: pass ? greaterThanReason : lessThanReason,
    tokensUsed,
  };
}
600

601
/**
 * Grades how much of the ground truth is attributable to the context.
 *
 * The text provider is asked to classify each ground-truth sentence; the score
 * is the fraction of non-blank reply lines that contain
 * `CONTEXT_RECALL_ATTRIBUTED_TOKEN`. Blank lines are ignored so that spacing in
 * the reply does not lower the score. A reply with no non-blank lines scores 0.
 *
 * @param context Retrieved context.
 * @param groundTruth Reference answer.
 * @param threshold Minimum recall required to pass.
 * @param grading Optional grading config supplying the provider.
 * @param vars Extra template variables; may override the built-in ones.
 * @returns Grading result with the recall score, or a failure when the provider
 *   errors or returns nothing.
 */
export async function matchesContextRecall(
  context: string,
  groundTruth: string,
  threshold: number,
  grading?: GradingConfig,
  vars?: Record<string, string | object>,
): Promise<Omit<GradingResult, 'assertion'>> {
  const textProvider = await getAndCheckProvider(
    'text',
    grading?.provider,
    DefaultGradingProvider,
    'context recall check',
  );

  const promptText = nunjucks.renderString(CONTEXT_RECALL, {
    context: JSON.stringify(context).slice(1, -1),
    groundTruth: JSON.stringify(groundTruth).slice(1, -1),
    ...fromVars(vars),
  });

  const resp = await textProvider.callApi(promptText);
  if (resp.error || !resp.output) {
    return fail(resp.error || 'No output', resp.tokenUsage);
  }

  invariant(typeof resp.output === 'string', 'context-recall produced malformed response');
  // Only lines with content are sentences; blank separators must not dilute recall.
  const sentences = resp.output.split('\n').filter((line) => line.trim() !== '');
  const numerator = sentences.reduce(
    (acc, sentence) => acc + (sentence.includes(CONTEXT_RECALL_ATTRIBUTED_TOKEN) ? 1 : 0),
    0,
  );
  const score = sentences.length > 0 ? numerator / sentences.length : 0;
  const pass = score >= threshold;
  return {
    pass,
    score,
    reason: pass
      ? `Recall ${score.toFixed(2)} is >= ${threshold}`
      : `Recall ${score.toFixed(2)} is < ${threshold}`,
    tokensUsed: {
      total: resp.tokenUsage?.total || 0,
      prompt: resp.tokenUsage?.prompt || 0,
      completion: resp.tokenUsage?.completion || 0,
    },
  };
}
647

648
/**
 * Grades how relevant the context is to the question.
 *
 * The text provider is asked to extract the relevant sentences; the score is
 * the fraction of non-blank reply lines that do NOT contain
 * `CONTEXT_RELEVANCE_BAD`. Blank lines are ignored so that spacing in the reply
 * is not counted as relevant content. A reply with no non-blank lines scores 0.
 *
 * @param question Query the context should help answer.
 * @param context Retrieved context.
 * @param threshold Minimum relevance required to pass.
 * @param grading Optional grading config supplying the provider.
 * @returns Grading result with the relevance score, or a failure when the
 *   provider errors or returns nothing.
 */
export async function matchesContextRelevance(
  question: string,
  context: string,
  threshold: number,
  grading?: GradingConfig,
): Promise<Omit<GradingResult, 'assertion'>> {
  const textProvider = await getAndCheckProvider(
    'text',
    grading?.provider,
    DefaultGradingProvider,
    'context relevance check',
  );

  const promptText = nunjucks.renderString(CONTEXT_RELEVANCE, {
    context: JSON.stringify(context).slice(1, -1),
    query: JSON.stringify(question).slice(1, -1),
  });

  const resp = await textProvider.callApi(promptText);
  if (resp.error || !resp.output) {
    return fail(resp.error || 'No output', resp.tokenUsage);
  }

  invariant(typeof resp.output === 'string', 'context-relevance produced malformed response');
  // Only lines with content are sentences; blank separators must not inflate relevance.
  const sentences = resp.output.split('\n').filter((line) => line.trim() !== '');
  const numerator = sentences.reduce(
    (acc, sentence) => acc + (sentence.includes(CONTEXT_RELEVANCE_BAD) ? 0 : 1),
    0,
  );
  const score = sentences.length > 0 ? numerator / sentences.length : 0;
  const pass = score >= threshold;
  return {
    pass,
    score,
    reason: pass
      ? `Relevance ${score.toFixed(2)} is >= ${threshold}`
      : `Relevance ${score.toFixed(2)} is < ${threshold}`,
    tokensUsed: {
      total: resp.tokenUsage?.total || 0,
      prompt: resp.tokenUsage?.prompt || 0,
      completion: resp.tokenUsage?.completion || 0,
    },
  };
}
692

693
/**
 * Grades how faithful `output` is to `context`.
 *
 * Two LLM calls are made: the first breaks the answer into statements (one per
 * non-blank reply line), the second judges each statement against the context.
 * The score is `1 - (unsupported statements / statements)`, clamped to [0, 1].
 *
 * Token usage from both calls is accumulated into the returned `tokensUsed`.
 *
 * @param query Original question.
 * @param output Answer under test.
 * @param context Retrieved context the answer should be grounded in.
 * @param threshold Minimum faithfulness required to pass.
 * @param grading Optional grading config supplying the provider.
 * @param vars Extra template variables; may override the built-in ones.
 * @returns Grading result with the faithfulness score, or a failure when a
 *   provider call errors, returns nothing, or yields no statements.
 */
export async function matchesContextFaithfulness(
  query: string,
  output: string,
  context: string,
  threshold: number,
  grading?: GradingConfig,
  vars?: Record<string, string | object>,
): Promise<Omit<GradingResult, 'assertion'>> {
  const textProvider = await getAndCheckProvider(
    'text',
    grading?.provider,
    DefaultGradingProvider,
    'faithfulness check',
  );

  const tokensUsed = {
    total: 0,
    prompt: 0,
    completion: 0,
  };
  const addTokenUsage = (usage?: { total?: number; prompt?: number; completion?: number }) => {
    tokensUsed.total += usage?.total || 0;
    tokensUsed.prompt += usage?.prompt || 0;
    tokensUsed.completion += usage?.completion || 0;
  };

  let promptText = nunjucks.renderString(CONTEXT_FAITHFULNESS_LONGFORM, {
    question: JSON.stringify(query).slice(1, -1),
    answer: JSON.stringify(output).slice(1, -1),
    ...fromVars(vars),
  });

  let resp = await textProvider.callApi(promptText);
  addTokenUsage(resp.tokenUsage);
  if (resp.error || !resp.output) {
    return fail(resp.error || 'No output', tokensUsed);
  }

  invariant(typeof resp.output === 'string', 'context-faithfulness produced malformed response');

  // Blank lines are not statements; counting them would inflate the denominator.
  const statements = resp.output.split('\n').filter((line) => line.trim() !== '');
  if (statements.length === 0) {
    return fail('No statements could be extracted from the output', tokensUsed);
  }

  promptText = nunjucks.renderString(CONTEXT_FAITHFULNESS_NLI_STATEMENTS, {
    context: JSON.stringify(context).slice(1, -1),
    statements: JSON.stringify(statements.join('\n')).slice(1, -1),
    ...fromVars(vars),
  });

  resp = await textProvider.callApi(promptText);
  addTokenUsage(resp.tokenUsage);
  if (resp.error || !resp.output) {
    return fail(resp.error || 'No output', tokensUsed);
  }

  invariant(typeof resp.output === 'string', 'context-faithfulness produced malformed response');

  const finalAnswer = 'Final verdict for each statement in order:'.toLowerCase();
  let verdicts = resp.output.toLowerCase().trim();
  let unsupported: number;
  if (verdicts.includes(finalAnswer)) {
    // Summary form: period-separated verdicts after the marker; anything not "yes" is unsupported.
    verdicts = verdicts.slice(verdicts.indexOf(finalAnswer) + finalAnswer.length);
    unsupported = verdicts
      .split('.')
      .filter((answer) => answer.trim() !== '' && !answer.includes('yes')).length;
  } else {
    // Per-statement form: count the occurrences of "verdict: no".
    unsupported = verdicts.split('verdict: no').length - 1;
  }
  // The grader may emit more verdicts than statements; keep the score within [0, 1].
  const score = Math.min(1, Math.max(0, 1 - unsupported / statements.length));
  const pass = score >= threshold;
  return {
    pass,
    score,
    reason: pass
      ? `Faithfulness ${score.toFixed(2)} is >= ${threshold}`
      : `Faithfulness ${score.toFixed(2)} is < ${threshold}`,
    tokensUsed,
  };
}
762

763
/**
 * Asks an LLM to pick the best of several outputs according to `criteria`.
 *
 * The first integer found in the reply is taken as the zero-based index of the
 * winning output. Every returned result is a distinct object, so callers may
 * mutate one without affecting the others.
 *
 * @param criteria Criterion used to choose the best output.
 * @param outputs Candidate outputs; at least two are required.
 * @param grading Optional grading config supplying the provider and rubric prompt.
 * @param vars Extra template variables; may override the built-in ones.
 * @returns One result per output, in order: the selected output passes with
 *   score 1, all others fail with score 0. When the provider errors or the
 *   verdict is not a valid index, every result is a failure.
 * @throws Via `invariant` when fewer than two outputs are given or the provider
 *   returns a non-string output.
 */
export async function matchesSelectBest(
  criteria: string,
  outputs: string[],
  grading?: GradingConfig,
  vars?: Record<string, string | object>,
): Promise<Omit<GradingResult, 'assertion'>[]> {
  invariant(
    outputs.length >= 2,
    'select-best assertion must have at least two outputs to compare between',
  );
  const textProvider = await getAndCheckProvider(
    'text',
    grading?.provider,
    DefaultGradingProvider,
    'select-best check',
  );

  const promptText = nunjucks.renderString(grading?.rubricPrompt || SELECT_BEST_PROMPT, {
    criteria: JSON.stringify(criteria).slice(1, -1),
    outputs: outputs.map((output) => JSON.stringify(output).slice(1, -1)),
    ...fromVars(vars),
  });

  const resp = await textProvider.callApi(promptText);
  if (resp.error || !resp.output) {
    const failureReason = resp.error || 'No output';
    // Build a fresh result per output rather than sharing one object via fill().
    return outputs.map(() => fail(failureReason, resp.tokenUsage));
  }

  invariant(typeof resp.output === 'string', 'select-best produced malformed response');

  // Read the whole number so indices of 10 and above are not truncated to their first digit.
  const verdictMatch = resp.output.trim().match(/\d+/);
  const verdict = verdictMatch ? parseInt(verdictMatch[0], 10) : NaN;

  if (isNaN(verdict) || verdict < 0 || verdict >= outputs.length) {
    // The call still consumed tokens; report them with the failure.
    return outputs.map(() => fail(`Invalid select-best verdict: ${verdict}`, resp.tokenUsage));
  }

  return outputs.map((_output, index) => {
    const selected = index === verdict;
    return {
      pass: selected,
      score: selected ? 1 : 0,
      reason: selected
        ? `Output selected as the best: ${criteria}`
        : `Output not selected: ${criteria}`,
      tokensUsed: {
        total: resp.tokenUsage?.total || 0,
        prompt: resp.tokenUsage?.prompt || 0,
        completion: resp.tokenUsage?.completion || 0,
      },
    };
  });
}
823

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.