1
import invariant from 'tiny-invariant';
2
import logger from './logger';
4
DefaultEmbeddingProvider,
5
DefaultGradingJsonProvider,
6
DefaultGradingProvider,
7
} from './providers/openai';
8
import { getNunjucksEngine } from './util';
9
import { loadApiProvider } from './providers';
11
ANSWER_RELEVANCY_GENERATE,
13
CONTEXT_FAITHFULNESS_LONGFORM,
14
CONTEXT_FAITHFULNESS_NLI_STATEMENTS,
16
CONTEXT_RECALL_ATTRIBUTED_TOKEN,
18
CONTEXT_RELEVANCE_BAD,
19
DEFAULT_GRADING_PROMPT,
20
OPENAI_CLOSED_QA_PROMPT,
21
OPENAI_FACTUALITY_PROMPT,
25
ApiClassificationProvider,
28
ApiSimilarityProvider,
36
const nunjucks = getNunjucksEngine();
38
function cosineSimilarity(vecA: number[], vecB: number[]) {
39
if (vecA.length !== vecB.length) {
40
throw new Error('Vectors must be of equal length');
42
const dotProduct = vecA.reduce((acc, val, idx) => acc + val * vecB[idx], 0);
43
const vecAMagnitude = Math.sqrt(vecA.reduce((acc, val) => acc + val * val, 0));
44
const vecBMagnitude = Math.sqrt(vecB.reduce((acc, val) => acc + val * val, 0));
45
return dotProduct / (vecAMagnitude * vecBMagnitude);
48
function fromVars(vars?: Record<string, string | object>) {
53
const ret: Record<string, string> = {};
54
for (const [key, value] of Object.entries(vars)) {
55
if (typeof value === 'object') {
56
ret[key] = JSON.stringify(value);
65
async function loadFromProviderOptions(provider: ProviderOptions) {
67
typeof provider === 'object',
68
`Provider must be an object, but received a ${typeof provider}: ${provider}`,
71
!Array.isArray(provider),
72
`Provider must be an object, but received an array: ${JSON.stringify(provider)}`,
74
invariant(provider.id, 'Provider supplied to assertion must have an id');
76
return loadApiProvider(provider.id, { options: provider as ProviderOptions });
79
export async function getGradingProvider(
80
type: 'embedding' | 'classification' | 'text',
81
provider: GradingConfig['provider'],
82
defaultProvider: ApiProvider | null,
83
): Promise<ApiProvider | null> {
84
let finalProvider: ApiProvider | null;
85
if (typeof provider === 'string') {
87
finalProvider = await loadApiProvider(provider);
88
} else if (typeof provider === 'object' && typeof (provider as ApiProvider).id === 'function') {
90
finalProvider = provider as ApiProvider;
91
} else if (typeof provider === 'object') {
92
const typeValue = (provider as ProviderTypeMap)[type];
95
finalProvider = await getGradingProvider(type, typeValue, defaultProvider);
96
} else if ((provider as ProviderOptions).id) {
98
finalProvider = await loadFromProviderOptions(provider as ProviderOptions);
101
`Invalid provider definition for output type '${type}': ${JSON.stringify(
109
finalProvider = defaultProvider;
111
return finalProvider;
114
export async function getAndCheckProvider(
115
type: 'embedding' | 'classification' | 'text',
116
provider: GradingConfig['provider'],
117
defaultProvider: ApiProvider | null,
119
): Promise<ApiProvider> {
120
let matchedProvider = await getGradingProvider(type, provider, defaultProvider);
121
if (!matchedProvider) {
122
if (defaultProvider) {
123
logger.warn(`No provider of type ${type} found for '${checkName}', falling back to default`);
124
return defaultProvider;
126
throw new Error(`No provider of type ${type} found for '${checkName}'`);
130
let isValidProviderType = true;
131
if (type === 'embedding') {
132
isValidProviderType =
133
'callEmbeddingApi' in matchedProvider || 'callSimilarityApi' in matchedProvider;
134
} else if (type === 'classification') {
135
isValidProviderType = 'callClassificationApi' in matchedProvider;
138
if (!isValidProviderType) {
139
if (defaultProvider) {
141
`Provider ${matchedProvider.id()} is not a valid ${type} provider for '${checkName}', falling back to default`,
143
return defaultProvider;
146
`Provider ${matchedProvider.id()} is not a valid ${type} provider for '${checkName}'`,
151
return matchedProvider;
154
function fail(reason: string, tokensUsed?: Partial<TokenUsage>): Omit<GradingResult, 'assertion'> {
160
total: tokensUsed?.total || 0,
161
prompt: tokensUsed?.prompt || 0,
162
completion: tokensUsed?.completion || 0,
167
export async function matchesSimilarity(
171
inverse: boolean = false,
172
grading?: GradingConfig,
173
): Promise<Omit<GradingResult, 'assertion'>> {
174
let finalProvider = (await getAndCheckProvider(
177
DefaultEmbeddingProvider,
179
)) as ApiEmbeddingProvider | ApiSimilarityProvider;
181
let similarity: number;
182
let tokensUsed: TokenUsage = {
188
if ('callSimilarityApi' in finalProvider) {
189
const similarityResp = await finalProvider.callSimilarityApi(expected, output);
192
...similarityResp.tokenUsage,
194
if (similarityResp.error) {
195
return fail(similarityResp.error, tokensUsed);
197
if (similarityResp.similarity == null) {
198
return fail('Unknown error fetching similarity', tokensUsed);
200
similarity = similarityResp.similarity;
201
} else if ('callEmbeddingApi' in finalProvider) {
202
const expectedEmbedding = await finalProvider.callEmbeddingApi(expected);
203
const outputEmbedding = await finalProvider.callEmbeddingApi(output);
206
total: (expectedEmbedding.tokenUsage?.total || 0) + (outputEmbedding.tokenUsage?.total || 0),
208
(expectedEmbedding.tokenUsage?.prompt || 0) + (outputEmbedding.tokenUsage?.prompt || 0),
210
(expectedEmbedding.tokenUsage?.completion || 0) +
211
(outputEmbedding.tokenUsage?.completion || 0),
214
if (expectedEmbedding.error || outputEmbedding.error) {
216
expectedEmbedding.error || outputEmbedding.error || 'Unknown error fetching embeddings',
221
if (!expectedEmbedding.embedding || !outputEmbedding.embedding) {
222
return fail('Embedding not found', tokensUsed);
225
similarity = cosineSimilarity(expectedEmbedding.embedding, outputEmbedding.embedding);
227
throw new Error('Provider must implement callSimilarityApi or callEmbeddingApi');
229
const pass = inverse ? similarity <= threshold : similarity >= threshold;
230
const greaterThanReason = `Similarity ${similarity.toFixed(
232
)} is greater than threshold ${threshold}`;
233
const lessThanReason = `Similarity ${similarity.toFixed(2)} is less than threshold ${threshold}`;
237
score: inverse ? 1 - similarity : similarity,
238
reason: inverse ? lessThanReason : greaterThanReason,
244
score: inverse ? 1 - similarity : similarity,
245
reason: inverse ? greaterThanReason : lessThanReason,
258
export async function matchesClassification(
259
expected: string | undefined,
262
grading?: GradingConfig,
263
): Promise<Omit<GradingResult, 'assertion'>> {
264
let finalProvider = (await getAndCheckProvider(
268
'classification check',
269
)) as ApiClassificationProvider;
271
const resp = await finalProvider.callClassificationApi(output);
273
if (!resp.classification) {
274
return fail(resp.error || 'Unknown error fetching classification');
277
if (expected === undefined) {
278
score = Math.max(...Object.values(resp.classification));
280
score = resp.classification[expected] || 0;
283
if (score >= threshold) {
285
expected === undefined
286
? `Maximum classification score ${score.toFixed(2)} >= ${threshold}`
287
: `Classification ${expected} has score ${score.toFixed(2)} >= ${threshold}`;
297
reason: `Classification ${expected} has score ${score.toFixed(2)} < ${threshold}`,
301
export async function matchesLlmRubric(
304
grading?: GradingConfig,
305
vars?: Record<string, string | object>,
306
): Promise<Omit<GradingResult, 'assertion'>> {
309
'Cannot grade output without grading config. Specify --grader option or grading config.',
313
const prompt = nunjucks.renderString(grading.rubricPrompt || DEFAULT_GRADING_PROMPT, {
314
output: JSON.stringify(output).slice(1, -1),
315
rubric: JSON.stringify(expected).slice(1, -1),
319
let finalProvider = await getAndCheckProvider(
322
DefaultGradingJsonProvider,
325
const resp = await finalProvider.callApi(prompt);
326
if (resp.error || !resp.output) {
327
return fail(resp.error || 'No output', resp.tokenUsage);
330
invariant(typeof resp.output === 'string', 'llm-rubric produced malformed response');
332
const firstBrace = resp.output.indexOf('{');
333
const lastBrace = resp.output.lastIndexOf('}');
334
const jsonStr = resp.output.substring(firstBrace, lastBrace + 1);
335
const parsed = JSON.parse(jsonStr) as Partial<GradingResult>;
336
const pass = parsed.pass ?? (typeof parsed.score === 'undefined' ? true : parsed.score > 0);
339
score: parsed.score ?? (pass ? 1.0 : 0.0),
340
reason: parsed.reason || (pass ? 'Grading passed' : 'Grading failed'),
342
total: resp.tokenUsage?.total || 0,
343
prompt: resp.tokenUsage?.prompt || 0,
344
completion: resp.tokenUsage?.completion || 0,
348
return fail(`llm-rubric produced malformed response: ${resp.output}`, resp.tokenUsage);
352
export async function matchesFactuality(
356
grading?: GradingConfig,
357
vars?: Record<string, string | object>,
358
): Promise<Omit<GradingResult, 'assertion'>> {
361
'Cannot grade output without grading config. Specify --grader option or grading config.',
365
const prompt = nunjucks.renderString(grading.rubricPrompt || OPENAI_FACTUALITY_PROMPT, {
366
input: JSON.stringify(input).slice(1, -1),
367
ideal: JSON.stringify(expected).slice(1, -1),
368
completion: JSON.stringify(output).slice(1, -1),
372
let finalProvider = await getAndCheckProvider(
375
DefaultGradingProvider,
378
const resp = await finalProvider.callApi(prompt);
379
if (resp.error || !resp.output) {
380
return fail(resp.error || 'No output', resp.tokenUsage);
383
invariant(typeof resp.output === 'string', 'factuality produced malformed response');
385
const output = resp.output;
387
const answerMatch = output.match(/\s*\(?([a-eA-E])\)/);
390
`Factuality checker output did not match expected format: ${output}`,
394
const option = answerMatch[1].toUpperCase();
398
const scoreLookup: Record<string, number> = {
399
A: grading.factuality?.subset ?? 1,
400
B: grading.factuality?.superset ?? 1,
401
C: grading.factuality?.agree ?? 1,
402
D: grading.factuality?.disagree ?? 0,
403
E: grading.factuality?.differButFactual ?? 1,
407
const passing = Object.keys(scoreLookup).filter((key) => scoreLookup[key] > 0);
408
const failing = Object.keys(scoreLookup).filter((key) => scoreLookup[key] === 0);
410
let pass = passing.includes(option) && !failing.includes(option);
411
const optionReasons: Record<string, string> = {
412
A: `The submitted answer is a subset of the expert answer and is fully consistent with it.`,
413
B: `The submitted answer is a superset of the expert answer and is fully consistent with it.`,
414
C: `The submitted answer contains all the same details as the expert answer.`,
415
D: `There is a disagreement between the submitted answer and the expert answer.`,
416
E: `The answers differ, but these differences don't matter from the perspective of factuality.`,
418
if (optionReasons[option]) {
419
reason = optionReasons[option];
422
reason = `Invalid option: ${option}. Full response from factuality checker: ${resp.output}`;
425
let score = pass ? 1 : 0;
426
if (typeof scoreLookup[option] !== 'undefined') {
427
score = scoreLookup[option];
435
total: resp.tokenUsage?.total || 0,
436
prompt: resp.tokenUsage?.prompt || 0,
437
completion: resp.tokenUsage?.completion || 0,
441
return fail(`Error parsing output: ${(err as Error).message}`, resp.tokenUsage);
445
export async function matchesClosedQa(
449
grading?: GradingConfig,
450
vars?: Record<string, string | object>,
451
): Promise<Omit<GradingResult, 'assertion'>> {
454
'Cannot grade output without grading config. Specify --grader option or grading config.',
458
const prompt = nunjucks.renderString(grading.rubricPrompt || OPENAI_CLOSED_QA_PROMPT, {
459
input: JSON.stringify(input).slice(1, -1),
460
criteria: JSON.stringify(expected).slice(1, -1),
461
completion: JSON.stringify(output).slice(1, -1),
465
let finalProvider = await getAndCheckProvider(
468
DefaultGradingProvider,
469
'model-graded-closedqa check',
471
const resp = await finalProvider.callApi(prompt);
472
if (resp.error || !resp.output) {
473
return fail(resp.error || 'No output', resp.tokenUsage);
476
invariant(typeof resp.output === 'string', 'model-graded-closedqa produced malformed response');
478
const pass = resp.output.endsWith('Y');
481
reason = 'The submission meets the criterion';
482
} else if (resp.output.endsWith('N')) {
483
reason = `The submission does not meet the criterion:\n${resp.output}`;
485
reason = `Model grader produced a malformed response:\n${resp.output}`;
492
total: resp.tokenUsage?.total || 0,
493
prompt: resp.tokenUsage?.prompt || 0,
494
completion: resp.tokenUsage?.completion || 0,
498
return fail(`Error parsing output: ${(err as Error).message}`, resp.tokenUsage);
502
export async function matchesAnswerRelevance(
506
grading?: GradingConfig,
507
): Promise<Omit<GradingResult, 'assertion'>> {
508
let embeddingProvider = await getAndCheckProvider(
511
DefaultEmbeddingProvider,
512
'answer relevancy check',
514
let textProvider = await getAndCheckProvider(
517
DefaultGradingProvider,
518
'answer relevancy check',
527
const candidateQuestions: string[] = [];
528
for (let i = 0; i < 3; i++) {
530
const resp = await textProvider.callApi(
532
ANSWER_RELEVANCY_GENERATE,
539
if (resp.error || !resp.output) {
540
tokensUsed.total += resp.tokenUsage?.total || 0;
541
tokensUsed.prompt += resp.tokenUsage?.prompt || 0;
542
tokensUsed.completion += resp.tokenUsage?.completion || 0;
543
return fail(resp.error || 'No output', tokensUsed);
547
typeof resp.output === 'string',
548
'answer relevancy check produced malformed response',
550
candidateQuestions.push(resp.output);
554
typeof embeddingProvider.callEmbeddingApi === 'function',
555
`Provider ${embeddingProvider.id} must implement callEmbeddingApi for similarity check`,
558
const inputEmbeddingResp = await embeddingProvider.callEmbeddingApi(input);
559
if (inputEmbeddingResp.error || !inputEmbeddingResp.embedding) {
560
tokensUsed.total += inputEmbeddingResp.tokenUsage?.total || 0;
561
tokensUsed.prompt += inputEmbeddingResp.tokenUsage?.prompt || 0;
562
tokensUsed.completion += inputEmbeddingResp.tokenUsage?.completion || 0;
563
return fail(inputEmbeddingResp.error || 'No embedding', tokensUsed);
565
const inputEmbedding = inputEmbeddingResp.embedding;
567
const similarities: number[] = [];
568
for (const question of candidateQuestions) {
569
const resp = await embeddingProvider.callEmbeddingApi(question);
570
tokensUsed.total += resp.tokenUsage?.total || 0;
571
tokensUsed.prompt += resp.tokenUsage?.prompt || 0;
572
tokensUsed.completion += resp.tokenUsage?.completion || 0;
573
if (resp.error || !resp.embedding) {
574
return fail(resp.error || 'No embedding', tokensUsed);
576
similarities.push(cosineSimilarity(inputEmbedding, resp.embedding));
579
const similarity = similarities.reduce((a, b) => a + b, 0) / similarities.length;
580
const pass = similarity >= threshold;
581
const greaterThanReason = `Relevance ${similarity.toFixed(
583
)} is greater than threshold ${threshold}`;
584
const lessThanReason = `Relevance ${similarity.toFixed(2)} is less than threshold ${threshold}`;
589
reason: greaterThanReason,
596
reason: lessThanReason,
601
export async function matchesContextRecall(
605
grading?: GradingConfig,
606
vars?: Record<string, string | object>,
607
): Promise<Omit<GradingResult, 'assertion'>> {
608
let textProvider = await getAndCheckProvider(
611
DefaultGradingProvider,
612
'context recall check',
615
const promptText = nunjucks.renderString(CONTEXT_RECALL, {
616
context: JSON.stringify(context).slice(1, -1),
617
groundTruth: JSON.stringify(groundTruth).slice(1, -1),
621
const resp = await textProvider.callApi(promptText);
622
if (resp.error || !resp.output) {
623
return fail(resp.error || 'No output', resp.tokenUsage);
626
invariant(typeof resp.output === 'string', 'context-recall produced malformed response');
627
const sentences = resp.output.split('\n');
628
const numerator = sentences.reduce(
629
(acc, sentence) => acc + (sentence.includes(CONTEXT_RECALL_ATTRIBUTED_TOKEN) ? 1 : 0),
632
const score = numerator / sentences.length;
633
const pass = score >= threshold;
638
? `Recall ${score.toFixed(2)} is >= ${threshold}`
639
: `Recall ${score.toFixed(2)} is < ${threshold}`,
641
total: resp.tokenUsage?.total || 0,
642
prompt: resp.tokenUsage?.prompt || 0,
643
completion: resp.tokenUsage?.completion || 0,
648
export async function matchesContextRelevance(
652
grading?: GradingConfig,
653
): Promise<Omit<GradingResult, 'assertion'>> {
654
let textProvider = await getAndCheckProvider(
657
DefaultGradingProvider,
658
'context relevance check',
661
const promptText = nunjucks.renderString(CONTEXT_RELEVANCE, {
662
context: JSON.stringify(context).slice(1, -1),
663
query: JSON.stringify(question).slice(1, -1),
666
const resp = await textProvider.callApi(promptText);
667
if (resp.error || !resp.output) {
668
return fail(resp.error || 'No output', resp.tokenUsage);
671
invariant(typeof resp.output === 'string', 'context-relevance produced malformed response');
672
const sentences = resp.output.split('\n');
673
const numerator = sentences.reduce(
674
(acc, sentence) => acc + (sentence.includes(CONTEXT_RELEVANCE_BAD) ? 0 : 1),
677
const score = numerator / sentences.length;
678
const pass = score >= threshold;
683
? `Relevance ${score.toFixed(2)} is >= ${threshold}`
684
: `Relevance ${score.toFixed(2)} is < ${threshold}`,
686
total: resp.tokenUsage?.total || 0,
687
prompt: resp.tokenUsage?.prompt || 0,
688
completion: resp.tokenUsage?.completion || 0,
693
export async function matchesContextFaithfulness(
698
grading?: GradingConfig,
699
vars?: Record<string, string | object>,
700
): Promise<Omit<GradingResult, 'assertion'>> {
701
let textProvider = await getAndCheckProvider(
704
DefaultGradingProvider,
705
'faithfulness check',
708
let promptText = nunjucks.renderString(CONTEXT_FAITHFULNESS_LONGFORM, {
709
question: JSON.stringify(query).slice(1, -1),
710
answer: JSON.stringify(output).slice(1, -1),
714
let resp = await textProvider.callApi(promptText);
715
if (resp.error || !resp.output) {
716
return fail(resp.error || 'No output', resp.tokenUsage);
719
invariant(typeof resp.output === 'string', 'context-faithfulness produced malformed response');
721
let statements = resp.output.split('\n');
722
promptText = nunjucks.renderString(CONTEXT_FAITHFULNESS_NLI_STATEMENTS, {
723
context: JSON.stringify(context).slice(1, -1),
724
statements: JSON.stringify(statements.join('\n')).slice(1, -1),
728
resp = await textProvider.callApi(promptText);
729
if (resp.error || !resp.output) {
730
return fail(resp.error || 'No output', resp.tokenUsage);
733
invariant(typeof resp.output === 'string', 'context-faithfulness produced malformed response');
735
let finalAnswer = 'Final verdict for each statement in order:';
736
finalAnswer = finalAnswer.toLowerCase();
737
let verdicts = resp.output.toLowerCase().trim();
739
if (verdicts.includes(finalAnswer)) {
740
verdicts = verdicts.slice(verdicts.indexOf(finalAnswer) + finalAnswer.length);
742
verdicts.split('.').filter((answer) => answer.trim() !== '' && !answer.includes('yes'))
743
.length / statements.length;
745
score = (verdicts.split('verdict: no').length - 1) / statements.length;
748
let pass = score >= threshold;
753
? `Faithfulness ${score.toFixed(2)} is >= ${threshold}`
754
: `Faithfulness ${score.toFixed(2)} is < ${threshold}`,
756
total: resp.tokenUsage?.total || 0,
757
prompt: resp.tokenUsage?.prompt || 0,
758
completion: resp.tokenUsage?.completion || 0,
763
export async function matchesSelectBest(
766
grading?: GradingConfig,
767
vars?: Record<string, string | object>,
768
): Promise<Omit<GradingResult, 'assertion'>[]> {
771
'select-best assertion must have at least two outputs to compare between',
773
let textProvider = await getAndCheckProvider(
776
DefaultGradingProvider,
780
let promptText = nunjucks.renderString(grading?.rubricPrompt || SELECT_BEST_PROMPT, {
781
criteria: JSON.stringify(criteria).slice(1, -1),
782
outputs: outputs.map((output) => JSON.stringify(output).slice(1, -1)),
786
let resp = await textProvider.callApi(promptText);
787
if (resp.error || !resp.output) {
788
return new Array(outputs.length).fill(fail(resp.error || 'No output', resp.tokenUsage));
791
invariant(typeof resp.output === 'string', 'select-best produced malformed response');
793
const firstDigitMatch = resp.output.trim().match(/\d/);
794
const verdict = firstDigitMatch ? parseInt(firstDigitMatch[0], 10) : NaN;
796
if (isNaN(verdict) || verdict < 0 || verdict >= outputs.length) {
797
return new Array(outputs.length).fill(fail(`Invalid select-best verdict: ${verdict}`));
801
total: resp.tokenUsage?.total || 0,
802
prompt: resp.tokenUsage?.prompt || 0,
803
completion: resp.tokenUsage?.completion || 0,
805
return outputs.map((output, index) => {
806
if (index === verdict) {
810
reason: `Output selected as the best: ${criteria}`,
817
reason: `Output not selected: ${criteria}`,