# google-research (246 lines · 6.4 KB)
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""DePlot Prompts."""
17
import collections
from collections.abc import Callable, Iterable, Mapping, Sequence
import enum
import functools
import json
import random
import time
from typing import TypeVar

from absl import flags
import openai
from pix2struct import metrics as pix2struct_metrics
from t5.evaluation import metrics as t5_metrics
import tensorflow as tf
32
33
34_OPENAI_CREDENTIALS = flags.DEFINE_list(35'openai_credentials', None, 'Credentials to call OpenAI.', required=True)36
37T = TypeVar('T')38TFn = Callable[Ellipsis, T]39
40
41class Model(enum.Enum):42GPT3 = 'gpt3'43
44
66
67
68
69@retry(try_count=3, sleep_seconds=1)70def _call_openai(71prompt,72engine,73max_decode_steps,74temperature,75top_p = 1,76frequency_penalty = 0,77presence_penalty = 0,78samples = 1,79stop = ('Q:', 'A:', 'Summary:', '\n\n')):80"""Issues a completion request to the engine, while retrying on failure.81
82Args:
83prompt: The prompt to send.
84engine: Model engine to use.
85max_decode_steps: The max_tokens parameter to send to the engine.
86temperature: Sampling temperature.
87top_p: Ratio of likelihood weighted token options to allow while sampling.
88frequency_penalty: Pentalty for the frequency of repeated tokens.
89presence_penalty: Penalty for the existence repeated tokens.
90samples: Number of outputs to generate.
91stop: Sequence of strings that elicit an end to decoding
92
93Returns:
94Text completion
95"""
96openai.api_key = random.choice(_OPENAI_CREDENTIALS.value)97
98try:99reply = openai.Completion.create(100engine=engine,101prompt=prompt,102temperature=temperature,103max_tokens=max_decode_steps,104top_p=top_p,105frequency_penalty=frequency_penalty,106presence_penalty=presence_penalty,107n=samples,108stop=stop)109return [choice['text'] for choice in reply['choices']] if reply else []110
111except openai.error.RateLimitError as e:112print('Sleeping 60 secs.')113time.sleep(60)114raise ValueError('RateLimitError') from e115
116
117def call_model(118model,119prompt,120use_code,121temperature,122max_decode_steps,123samples,124):125"""Calls model given a prompt."""126results = []127while len(results) < samples:128if model == Model.GPT3:129results.extend(130_call_openai(131prompt,132engine='code-davinci-002' if use_code else 'text-davinci-003',133temperature=temperature,134max_decode_steps=max_decode_steps,135samples=samples))136else:137raise ValueError(f'Unknown model_type={model}')138return results[:samples]139
140
141def chunks(142generator,143chunk_size,144filter_fn):145"""Splits generator into chunks."""146chunk = []147idx = 0148skipped = 0149
150for item in generator:151if not filter_fn(item):152skipped += 1153continue154if len(chunk) >= chunk_size:155yield idx, chunk156idx += 1157chunk = [item]158else:159chunk.append(item)160
161if chunk:162yield idx, chunk163print('Total skipped', skipped)164
165
166def _majority(predictions):167"""Finds most frequent result among the first N predictions for each N."""168result = []169counter = collections.Counter()170for prediction in predictions:171if prediction:172counter[prediction] += 1173if counter:174result.append(counter.most_common(1)[0][0])175else:176result.append('')177return result178
179
180def _exec(code):181"""Executed model output and returns the `ans` variable."""182
183def execute(x):184try:185exec(x) # pylint: disable=exec-used186answer = locals().get('ans', '')187if isinstance(answer, str):188return answer189elif isinstance(answer, bool):190return 'Yes' if answer else 'No'191elif isinstance(answer, collections.abc.Sequence):192return ', '.join(str(a) for a in answer)193return str(answer)194except Exception: # pylint: disable=broad-except195return ''196
197return execute(code)198
199
200def _extract_answer(prediction):201output = prediction.split('\n\n')[0]202if output.lower().startswith('#python'):203return _exec(output)204return output.split('answer is')[-1].strip().rstrip('.').strip()205
206
207def _extract_answers(predictions):208return [_extract_answer(output) for output in predictions]209
210
211def compute_metrics(files, is_qa):212"""Computes the metrics given the list of prediction files."""213targets = []214predictions = []215if is_qa:216def metric_fn(targets, predictions):217return dict(218relaxed_accuracy=pix2struct_metrics.aggregate_metrics(219targets=targets,220predictions=predictions,221metric_fn=pix2struct_metrics.relaxed_correctness,222)223)224
225else:226metric_fn = t5_metrics.bleu227for predictions_file in files:228with tf.io.gfile.GFile(predictions_file) as f:229for line in f:230prediction = json.loads(line)231# Each prediction line contains a list of targets but only one is used.232targets.append(prediction['target'][0])233if is_qa:234predictions.append(235_majority(_extract_answers(prediction['predictions']))236)237else:238predictions.append(239[p.split('\n')[0] for p in prediction['predictions']]240)241metrics = {}242for idx, sampled_predictions in enumerate(zip(*predictions)):243metric = metric_fn(targets, list(sampled_predictions))244for key, value in metric.items():245metrics[f'{key}_maj{idx + 1}'] = value246return metrics247