google-research

Форк
0
/
llm_utils.py 
246 строк · 6.4 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""DePlot Prompts."""
17

18
import collections
19
from collections.abc import Callable, Iterable, Mapping, Sequence
20
import enum
21
import json
22
import random
23
import time
24
from typing import TypeVar
25

26
from absl import flags
27
import openai
28
from pix2struct import metrics as pix2struct_metrics
29
from t5.evaluation import metrics as t5_metrics
30
import tensorflow as tf
31

32

33

34
# List-valued flag holding OpenAI API keys; `_call_openai` draws one of them at
# random for every request. Marked required, so any binary importing this
# module must supply it.
_OPENAI_CREDENTIALS = flags.DEFINE_list(
    name='openai_credentials',
    default=None,
    help='Credentials to call OpenAI.',
    required=True,
)
36

37
# Generic result type and an alias for "any callable that returns T".
T = TypeVar('T')
TFn = Callable[..., T]
39

40

41
class Model(enum.Enum):
  """Model backends that `call_model` knows how to query."""

  # OpenAI GPT-3 family, queried through the legacy Completion endpoint.
  GPT3 = 'gpt3'
43

44

45
def retry(
    try_count: int = 3,
    sleep_seconds: float = 2,
):
  """Retry decorator with exponential backoff.

  The wrapped function is retried only when it raises `ValueError`; any other
  exception propagates immediately. After failed attempt `i` (0-based) the
  wrapper sleeps `sleep_seconds * 2**i` seconds before trying again. No sleep
  happens after the final attempt.

  Args:
    try_count: Total number of attempts; must be >= 1.
    sleep_seconds: Base backoff delay in seconds.

  Returns:
    A decorator that adds the retry behaviour to a function.

  Raises:
    ValueError: If `try_count` < 1 (raised when the decorator is built).
      The wrapped function raises `ValueError('No more retries')`, chained to
      the last underlying `ValueError`, once all attempts have failed.
  """
  import functools  # pylint: disable=g-import-not-at-top

  if try_count < 1:
    # Previously the wrapper silently returned None without ever calling fn.
    raise ValueError(f'try_count must be >= 1, got {try_count}')

  def decorator(fn):

    @functools.wraps(fn)
    def newfn(*args, **kwargs):
      for idx in range(try_count):
        try:
          return fn(*args, **kwargs)
        except ValueError as e:
          if idx == try_count - 1:
            # Out of attempts: give up right away instead of sleeping first.
            raise ValueError('No more retries') from e
          time.sleep(sleep_seconds * (2**idx))

    return newfn

  return decorator
65

66

67

68

69
@retry(try_count=3, sleep_seconds=1)
def _call_openai(
    prompt,
    engine,
    max_decode_steps,
    temperature,
    top_p = 1,
    frequency_penalty = 0,
    presence_penalty = 0,
    samples = 1,
    stop = ('Q:', 'A:', 'Summary:', '\n\n')):
  """Sends one completion request to OpenAI and returns the sampled texts.

  A key is drawn at random from the `openai_credentials` flag and assigned to
  the global `openai.api_key` on every call. A `RateLimitError` is turned into
  a `ValueError` after a 60 second pause, which makes the `retry` decorator
  (it only retries on `ValueError`) try the request again.

  Args:
    prompt: The prompt to send.
    engine: Model engine to use.
    max_decode_steps: Value passed as the request's `max_tokens`.
    temperature: Sampling temperature.
    top_p: Value passed as the request's `top_p`.
    frequency_penalty: Value passed as the request's `frequency_penalty`.
    presence_penalty: Value passed as the request's `presence_penalty`.
    samples: Number of completions to request (sent as `n`).
    stop: Strings that end decoding when generated.

  Returns:
    A list holding the `text` field of each returned choice, or an empty list
    when the reply is empty.
  """
  openai.api_key = random.choice(_OPENAI_CREDENTIALS.value)

  request = dict(
      engine=engine,
      prompt=prompt,
      temperature=temperature,
      max_tokens=max_decode_steps,
      top_p=top_p,
      frequency_penalty=frequency_penalty,
      presence_penalty=presence_penalty,
      n=samples,
      stop=stop,
  )
  try:
    reply = openai.Completion.create(**request)
  except openai.error.RateLimitError as e:
    print('Sleeping 60 secs.')
    time.sleep(60)
    raise ValueError('RateLimitError') from e

  if not reply:
    return []
  return [choice['text'] for choice in reply['choices']]
115

116

117
def call_model(
    model,
    prompt,
    use_code,
    temperature,
    max_decode_steps,
    samples,
):
  """Samples `samples` completions for `prompt` from `model`.

  Args:
    model: A `Model` member; only `Model.GPT3` is supported.
    prompt: Prompt text sent to the model.
    use_code: If true, query 'code-davinci-002', otherwise 'text-davinci-003'.
    temperature: Sampling temperature.
    max_decode_steps: Maximum number of tokens decoded per completion.
    samples: Number of completions to return.

  Returns:
    A list of exactly `samples` completion strings (empty when `samples` <= 0).

  Raises:
    ValueError: If `model` is not supported, or if a request comes back with
      no completions at all (which would otherwise loop forever).
  """
  # Both of these are loop-invariant; decide them once up front.
  if model != Model.GPT3:
    raise ValueError(f'Unknown model_type={model}')
  engine = 'code-davinci-002' if use_code else 'text-davinci-003'

  results = []
  while len(results) < samples:
    # Only ask for the completions that are still missing.
    batch = _call_openai(
        prompt,
        engine=engine,
        temperature=temperature,
        max_decode_steps=max_decode_steps,
        samples=samples - len(results))
    if not batch:
      # `_call_openai` returns [] for an empty reply; without this check the
      # loop would never make progress.
      raise ValueError(
          f'{model} returned no completions ({len(results)}/{samples} so far)')
    results.extend(batch)
  return results[:samples]
139

140

141
def chunks(
    generator,
    chunk_size,
    filter_fn=None):
  """Splits an iterable into consecutive, numbered chunks.

  Items rejected by `filter_fn` are dropped and counted. The kept items are
  grouped in order into lists of `chunk_size` elements, except possibly the
  last one, which holds the remainder. Once the input is exhausted, the number
  of skipped items is printed.

  Args:
    generator: Any iterable of items.
    chunk_size: Number of items per chunk; must be >= 1.
    filter_fn: Predicate selecting which items to keep; `None` keeps all.

  Yields:
    `(idx, chunk)` tuples, where `idx` numbers the chunks starting at 0.

  Raises:
    ValueError: If `chunk_size` < 1 (raised when iteration starts). Before this
      check, a non-positive size produced a spurious empty first chunk.
  """
  if chunk_size < 1:
    raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')

  chunk = []
  idx = 0
  skipped = 0

  for item in generator:
    if filter_fn is not None and not filter_fn(item):
      skipped += 1
      continue
    chunk.append(item)
    # Emit as soon as the chunk is full rather than waiting for the next item.
    if len(chunk) == chunk_size:
      yield idx, chunk
      idx += 1
      chunk = []

  if chunk:
    yield idx, chunk
  print('Total skipped', skipped)
164

165

166
def _majority(predictions):
167
  """Finds most frequent result among the first N predictions for each N."""
168
  result = []
169
  counter = collections.Counter()
170
  for prediction in predictions:
171
    if prediction:
172
      counter[prediction] += 1
173
    if counter:
174
      result.append(counter.most_common(1)[0][0])
175
    else:
176
      result.append('')
177
  return result
178

179

180
def _exec(code):
181
  """Executed model output and returns the `ans` variable."""
182

183
  def execute(x):
184
    try:
185
      exec(x)  # pylint: disable=exec-used
186
      answer = locals().get('ans', '')
187
      if isinstance(answer, str):
188
        return answer
189
      elif isinstance(answer, bool):
190
        return 'Yes' if answer else 'No'
191
      elif isinstance(answer, collections.abc.Sequence):
192
        return ', '.join(str(a) for a in answer)
193
      return str(answer)
194
    except Exception:  # pylint: disable=broad-except
195
      return ''
196

197
  return execute(code)
198

199

200
def _extract_answer(prediction):
201
  output = prediction.split('\n\n')[0]
202
  if output.lower().startswith('#python'):
203
    return _exec(output)
204
  return output.split('answer is')[-1].strip().rstrip('.').strip()
205

206

207
def _extract_answers(predictions):
  """Applies `_extract_answer` to every completion, preserving order."""
  return list(map(_extract_answer, predictions))
209

210

211
def compute_metrics(files, is_qa):
  """Computes metrics per sample index from JSON-lines prediction files.

  Each line of each file is a JSON object with a `target` list (only its first
  element is used) and a `predictions` list of raw completions.

  - QA mode: each example's completions become a running majority vote over
    the extracted answers, scored with pix2struct relaxed accuracy.
  - Otherwise: each completion is cut at its first newline and scored with
    t5 BLEU.

  Args:
    files: Paths of the prediction files, read with `tf.io.gfile.GFile`.
    is_qa: Selects the QA metric and answer extraction.

  Returns:
    A dict mapping `f'{metric_name}_maj{N}'` to the metric computed over the
    N-th processed prediction of every example (N starting at 1).
  """
  if is_qa:
    def metric_fn(targets, predictions):
      return dict(
          relaxed_accuracy=pix2struct_metrics.aggregate_metrics(
              targets=targets,
              predictions=predictions,
              metric_fn=pix2struct_metrics.relaxed_correctness,
          )
      )

    def postprocess(raw):
      return _majority(_extract_answers(raw))
  else:
    metric_fn = t5_metrics.bleu

    def postprocess(raw):
      return [p.split('\n')[0] for p in raw]

  all_targets = []
  per_example = []
  for path in files:
    with tf.io.gfile.GFile(path) as handle:
      for line in handle:
        record = json.loads(line)
        # Each line carries a list of targets, but only the first one is used.
        all_targets.append(record['target'][0])
        per_example.append(postprocess(record['predictions']))

  results = {}
  # Transpose: column N holds the N-th processed prediction of every example.
  for rank, column in enumerate(zip(*per_example), start=1):
    for key, value in metric_fn(all_targets, list(column)).items():
      results[f'{key}_maj{rank}'] = value
  return results
247

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.