promptfoo

Форк
0
/
testCases.ts 
363 строки · 11.0 Кб
1
import * as path from 'path';
2
import * as fs from 'fs';
3

4
import yaml from 'js-yaml';
5
import { parse as parsePath } from 'path';
6
import { parse as parseCsv } from 'csv-parse/sync';
7
import { globSync } from 'glob';
8

9
import logger from './logger';
10
import { fetchCsvFromGoogleSheet } from './fetch';
11
import { OpenAiChatCompletionProvider } from './providers/openai';
12
import { testCaseFromCsvRow } from './csv';
13

14
import type {
15
  Assertion,
16
  CsvRow,
17
  TestCase,
18
  TestSuite,
19
  TestSuiteConfig,
20
  UnifiedConfig,
21
  VarMapping,
22
} from './types';
23

24
// Model name handed to OpenAiChatCompletionProvider when synthesize() builds its provider.
const SYNTHESIZE_DEFAULT_PROVIDER = 'gpt-4-0125-preview';
25

26
/**
 * Decodes a JSON document.
 *
 * @param json - Text to decode.
 * @returns The decoded value, or `undefined` when `json` is not valid JSON
 *   (the parse error is deliberately swallowed).
 */
function parseJson(json: string): any | undefined {
  let decoded: any;
  try {
    decoded = JSON.parse(json);
  } catch {
    decoded = undefined;
  }
  return decoded;
}
33

34
/**
 * Loads variables from one or more YAML files and merges them into one map.
 *
 * Every entry of `pathOrGlobs` is resolved against `basePath` and expanded as
 * a glob. Each matching file is parsed as YAML and shallow-merged into the
 * result in iteration order, so a key from a later file overwrites the same
 * key from an earlier one.
 *
 * @param pathOrGlobs - A single path/glob, or a list of them.
 * @param basePath - Directory that relative entries are resolved against.
 * @returns The merged variables.
 */
export async function readVarsFiles(
  pathOrGlobs: string | string[],
  basePath: string = '',
): Promise<Record<string, string | string[] | object>> {
  const patterns = typeof pathOrGlobs === 'string' ? [pathOrGlobs] : pathOrGlobs;
  const merged: Record<string, string | string[] | object> = {};

  for (const pattern of patterns) {
    const matches = globSync(path.resolve(basePath, pattern), {
      windowsPathsNoEscape: true,
    });
    for (const file of matches) {
      Object.assign(merged, yaml.load(fs.readFileSync(file, 'utf-8')));
    }
  }

  return merged;
}
57

58
/**
 * Reads a standalone file of tests (or test equivalents) and converts each row
 * into a `TestCase`. Despite the `varsPath` parameter name, the file holds
 * tests, not bare variables.
 *
 * Supported sources:
 * - a Google Sheets URL (fetched as CSV),
 * - `.csv` files,
 * - `.json` files,
 * - `.yaml` / `.yml` files.
 *
 * Any other extension produces an empty list.
 *
 * @param varsPath - Path (resolved against `basePath`) or Google Sheets URL.
 * @param basePath - Directory that `varsPath` is resolved against.
 * @returns One test per row, each described as `Row #<n>` (1-based).
 * @throws Error when a JSON or YAML source does not decode to an array of rows.
 */
export async function readStandaloneTestsFile(
  varsPath: string,
  basePath: string = '',
): Promise<TestCase[]> {
  const resolvedVarsPath = path.resolve(basePath, varsPath);
  const fileExtension = parsePath(varsPath).ext.slice(1);
  let rows: CsvRow[] = [];

  if (varsPath.startsWith('https://docs.google.com/spreadsheets/')) {
    const csvData = await fetchCsvFromGoogleSheet(varsPath);
    rows = parseCsv(csvData, { columns: true });
  } else if (fileExtension === 'csv') {
    rows = parseCsv(fs.readFileSync(resolvedVarsPath, 'utf-8'), { columns: true });
  } else if (fileExtension === 'json') {
    rows = parseJson(fs.readFileSync(resolvedVarsPath, 'utf-8'));
  } else if (fileExtension === 'yaml' || fileExtension === 'yml') {
    rows = yaml.load(fs.readFileSync(resolvedVarsPath, 'utf-8')) as CsvRow[];
  }

  // parseJson yields undefined for malformed JSON, and a YAML/JSON document may
  // decode to something other than a list. Fail with a clear message instead of
  // an opaque "rows.map is not a function" TypeError.
  if (!Array.isArray(rows)) {
    throw new Error(
      `Expected a list of test rows in ${varsPath}, but the file could not be parsed into an array.`,
    );
  }

  return rows.map((row, idx) => {
    const test = testCaseFromCsvRow(row);
    test.description = `Row #${idx + 1}`;
    return test;
  });
}
85

86
/**
 * A test case as it may be written in config: `vars` is either an inline map
 * or a path/glob (or list of them) naming YAML files to load the map from.
 */
type TestCaseWithVarsFile = TestCase<
  Record<string, string | string[] | object> | string | string[]
>;

/**
 * Loads a single test case and resolves its `vars`.
 *
 * When `test` is a string it is treated as the path of a YAML file (resolved
 * against `basePath`); vars files referenced by that test are then resolved
 * relative to the YAML file's own directory. When `test` is an object, vars
 * files are resolved relative to `basePath`.
 *
 * @param test - Path to a YAML test file, or an in-memory test case.
 * @param basePath - Directory used to resolve relative paths.
 * @returns The test case with `vars` loaded.
 * @throws Error when the result has none of `assert`, `vars`, or `options`.
 */
export async function readTest(
  test: string | TestCaseWithVarsFile,
  basePath: string = '',
): Promise<TestCase> {
  let rawTestCase: TestCaseWithVarsFile;
  let varsBasePath: string;

  if (typeof test === 'string') {
    const testFilePath = path.resolve(basePath, test);
    rawTestCase = yaml.load(fs.readFileSync(testFilePath, 'utf-8')) as TestCaseWithVarsFile;
    varsBasePath = path.dirname(testFilePath);
  } else {
    rawTestCase = test;
    varsBasePath = basePath;
  }

  const testCase: TestCase = { ...rawTestCase, vars: undefined };
  const rawVars = rawTestCase.vars;

  // A string or array names vars files to load; anything else is used as-is.
  // Values inside an inline vars map are passed through untouched (no
  // `file://` expansion happens here).
  testCase.vars =
    typeof rawVars === 'string' || Array.isArray(rawVars)
      ? await readVarsFiles(rawVars, varsBasePath)
      : rawVars;

  // Validation of the shape of test
  if (!(testCase.assert || testCase.vars || testCase.options)) {
    const dump = JSON.stringify(testCase, null, 2);
    throw new Error(
      `Test case must have either assert, vars, or options property. Instead got ${dump}`,
    );
  }

  return testCase;
}
147

148
/**
 * Loads every test described by the `tests` entry of a test suite config.
 *
 * - A string ending in `yaml`/`yml` is expanded as a glob of test files.
 * - Any other string is read as a standalone tests file (legacy `vars.csv`).
 * - An array may mix glob strings and inline test cases; they are loaded in
 *   order.
 * - Anything else yields an empty list.
 *
 * @param tests - The `tests` value from the config.
 * @param basePath - Directory used to resolve relative paths and globs.
 * @returns The loaded test cases.
 * @throws Error when a glob matches a file that is not CSV, YAML, or JSON.
 */
export async function readTests(
  tests: TestSuiteConfig['tests'],
  basePath: string = '',
): Promise<TestCase[]> {
  // Expands one glob and loads every test from each matching file.
  const loadFromGlob = async (pattern: string) => {
    const matchedFiles = globSync(path.resolve(basePath, pattern), {
      windowsPathsNoEscape: true,
    });

    const loaded = [];
    for (const file of matchedFiles) {
      let fileTests: TestCase[] | undefined;
      if (file.endsWith('.csv')) {
        fileTests = await readStandaloneTestsFile(file, basePath);
      } else if (file.endsWith('.yaml') || file.endsWith('.yml')) {
        fileTests = yaml.load(fs.readFileSync(file, 'utf-8')) as TestCase[];
      } else if (file.endsWith('.json')) {
        fileTests = require(file);
      } else {
        throw new Error(`Unsupported file type for test file: ${file}`);
      }

      if (!fileTests) {
        continue;
      }
      // Paths inside a test file are relative to that file's directory.
      const fileDir = path.dirname(file);
      for (const entry of fileTests) {
        loaded.push(await readTest(entry, fileDir));
      }
    }
    return loaded;
  };

  if (typeof tests === 'string') {
    const pointsToYaml = tests.endsWith('yaml') || tests.endsWith('yml');
    // YAML: a tests file with multiple test cases. Otherwise: a legacy vars.csv.
    return pointsToYaml ? loadFromGlob(tests) : readStandaloneTestsFile(tests, basePath);
  }

  const collected: TestCase[] = [];
  if (!Array.isArray(tests)) {
    return collected;
  }

  for (const globOrTest of tests) {
    if (typeof globOrTest === 'string') {
      collected.push(...(await loadFromGlob(globOrTest)));
    } else {
      collected.push(await readTest(globOrTest, basePath));
    }
  }
  return collected;
}
202

203
/** Inputs accepted by `synthesize`. */
interface SynthesizeOptions {
  /** Prompt templates to generate test variables for; at least one is required. */
  prompts: string[];
  /** Existing tests; the `vars` of up to 100 of them are shown to the model as examples. */
  tests: TestCase[];
  /** Extra free-form guidance appended to the test-generation prompt. */
  instructions?: string;
  /** Maximum number of personas to ask the model for (falls back to 5). */
  numPersonas?: number;
  /** Number of test cases requested per persona (falls back to 3). */
  numTestCasesPerPersona?: number;
}
210

211
/**
 * Runs `synthesize` for a test suite, taking the prompts (their `raw` text)
 * and existing tests from the suite itself.
 *
 * @param testSuite - Suite providing the prompts and any existing tests.
 * @param options - Remaining synthesis settings; any `prompts`/`tests` given
 *   here are overridden by the suite's values.
 * @returns Whatever `synthesize` resolves to.
 */
export async function synthesizeFromTestSuite(
  testSuite: TestSuite,
  options: Partial<SynthesizeOptions>,
) {
  const prompts = testSuite.prompts.map(({ raw }) => raw);
  const tests = testSuite.tests || [];
  return synthesize({ ...options, prompts, tests });
}
221

222
/**
 * Generates new test-case variable sets for the given prompts using an LLM.
 *
 * Steps:
 * 1. Ask the model for up to `numPersonas` user personas that would send the
 *    prompts.
 * 2. Extract `{{variable}}` names from the prompt templates.
 * 3. For each persona, ask the model for `numTestCasesPerPersona` sets of
 *    variable values (showing it the vars of up to 100 existing tests).
 * 4. Return the collected sets with exact duplicates removed.
 *
 * A CLI progress bar is shown unless `LOG_LEVEL` is `debug`; it is always
 * stopped before this function settles, including on failure.
 *
 * @param prompts - Prompt templates; at least one is required.
 * @param instructions - Optional extra guidance appended to the persona prompt.
 * @param tests - Existing tests used as examples.
 * @param numPersonas - Maximum personas to request (falls back to 5).
 * @param numTestCasesPerPersona - Test cases requested per persona (falls back to 3).
 * @returns De-duplicated variable mappings, one per synthesized test case.
 * @throws Error when `prompts` is empty, or when the model returns no text
 *   output or JSON of an unexpected shape.
 * @throws SyntaxError when the model output is not valid JSON.
 */
export async function synthesize({
  prompts,
  instructions,
  tests,
  numPersonas,
  numTestCasesPerPersona,
}: SynthesizeOptions): Promise<VarMapping[]> {
  if (prompts.length < 1) {
    throw new Error('Dataset synthesis requires at least one prompt.');
  }

  numPersonas = numPersonas || 5;
  numTestCasesPerPersona = numTestCasesPerPersona || 3;

  // Parses a provider output as JSON, failing clearly when no text came back
  // (e.g. the API call errored) instead of letting JSON.parse choke on it.
  const parseOutput = <T>(output: unknown, what: string): T => {
    if (typeof output !== 'string') {
      throw new Error(
        `Dataset synthesis failed: the provider returned no text output while generating ${what}.`,
      );
    }
    return JSON.parse(output) as T;
  };

  let progressBar;
  if (process.env.LOG_LEVEL !== 'debug') {
    const cliProgress = await import('cli-progress');
    progressBar = new cliProgress.SingleBar({}, cliProgress.Presets.shades_classic);
    const totalProgressSteps = 1 + numPersonas * numTestCasesPerPersona;
    progressBar.start(totalProgressSteps, 0);
  }

  // Everything after the bar starts runs inside try/finally so the bar is
  // stopped even when an API call or JSON parsing throws.
  try {
    logger.debug(
      `Starting dataset synthesis. We'll begin by generating up to ${numPersonas} personas. Each persona will be used to generate ${numTestCasesPerPersona} test cases.`,
    );

    // Consider the following prompt for an LLM application: {{prompt}}. List up to 5 user personas that would send this prompt.
    logger.debug(`\nGenerating user personas from ${prompts.length} prompts...`);
    const provider = new OpenAiChatCompletionProvider(SYNTHESIZE_DEFAULT_PROVIDER, {
      config: {
        temperature: 1.0,
        response_format: {
          type: 'json_object',
        },
      },
    });
    const promptsString = `<Prompts>
${prompts.map((prompt) => `<Prompt>\n${prompt}\n</Prompt>`).join('\n')}
</Prompts>`;
    const resp = await provider.callApi(
      `Consider the following prompt${prompts.length > 1 ? 's' : ''} for an LLM application:
${promptsString}

List up to ${numPersonas} user personas that would send ${
        prompts.length > 1 ? 'these prompts' : 'this prompt'
      }. Your response should be JSON of the form {personas: string[]}`,
    );

    const personas = parseOutput<{ personas: string[] }>(resp.output, 'personas').personas;
    if (!Array.isArray(personas)) {
      throw new Error(
        'Dataset synthesis failed: expected the model to return JSON of the form {personas: string[]}.',
      );
    }
    logger.debug(
      `\nGenerated ${personas.length} personas:\n${personas.map((p) => `  - ${p}`).join('\n')}`,
    );

    if (progressBar) {
      progressBar.increment();
    }

    // Extract variable names from the nunjucks template in the prompts
    const variableRegex = /{{\s*(\w+)\s*}}/g;
    const variables = new Set<string>();
    for (const prompt of prompts) {
      let match;
      // exec() returning null resets lastIndex, so the regex is safe to reuse
      // for the next prompt.
      while ((match = variableRegex.exec(prompt)) !== null) {
        variables.add(match[1]);
      }
    }
    logger.debug(
      `\nExtracted ${variables.size} variables from prompts:\n${Array.from(variables)
        .map((v) => `  - ${v}`)
        .join('\n')}`,
    );

    const existingTests =
      `Here are some existing tests:` +
      tests
        .map((test) => {
          if (!test.vars) {
            return;
          }
          return `<Test>
${JSON.stringify(test.vars, null, 2)}
</Test>
    `;
        })
        .filter(Boolean)
        .slice(0, 100)
        .join('\n');

    // For each user persona, we will generate a map of variable names to values
    const testCaseVars: VarMapping[] = [];
    for (let i = 0; i < personas.length; i++) {
      const persona = personas[i];
      logger.debug(`\nGenerating test cases for persona ${i + 1}...`);
      // Construct the prompt for the LLM to generate variable values
      const personaPrompt = `Consider ${
        prompts.length > 1 ? 'these prompts' : 'this prompt'
      }, which contains some {{variables}}: \n${promptsString}

This is your persona:
<Persona>
${persona}
</Persona>

${existingTests}

Fully embody this persona and determine a value for each variable, such that the prompt would be sent by this persona.

You are a tester, so try to think of ${numTestCasesPerPersona} sets of values that would be interesting or unusual to test. ${
        instructions || ''
      }

Your response should contain a JSON map of variable names to values, of the form {vars: {${Array.from(
        variables,
      )
        .map((varName) => `${varName}: string`)
        .join(', ')}}[]}`;
      // Call the LLM API with the constructed prompt
      const personaResponse = await provider.callApi(personaPrompt);
      const parsed = parseOutput<{ vars: VarMapping[] }>(
        personaResponse.output,
        `test cases for persona ${i + 1}`,
      );
      if (!Array.isArray(parsed.vars)) {
        throw new Error(
          `Dataset synthesis failed: expected the model to return JSON of the form {vars: [...]} for persona ${
            i + 1
          }.`,
        );
      }
      for (const vars of parsed.vars) {
        logger.debug(`${JSON.stringify(vars, null, 2)}`);
        testCaseVars.push(vars);
        if (progressBar) {
          progressBar.increment();
        }
      }
    }

    // Dedup test case vars (compared by their JSON serialization)
    const uniqueTestCaseStrings = new Set(
      testCaseVars.map((testCase) => JSON.stringify(testCase)),
    );
    const dedupedTestCaseVars: VarMapping[] = Array.from(uniqueTestCaseStrings).map((testCase) =>
      JSON.parse(testCase),
    );
    return dedupedTestCaseVars;
  } finally {
    if (progressBar) {
      progressBar.stop();
    }
  }
}
364

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.