google-research

Форк
0
277 строк · 8.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Binary of evaluating instruction following. See README.md."""
17

18
import collections
19
import dataclasses
20
import json
21
import os
22
from typing import Dict, Optional, Sequence, Union
23

24
from absl import app
25
from absl import flags
26
from absl import logging
27

28
from instruction_following_eval import instructions_registry
29

30

31
# Path to the JSONL file of evaluation prompts (one InputExample per line).
_INPUT_DATA = flags.DEFINE_string(
    "input_data", None, "path to input data", required=True
)

# Path to the JSONL file mapping each prompt to a model response.
_INPUT_RESPONSE_DATA = flags.DEFINE_string(
    "input_response_data", None, "path to input response data", required=False
)

# Directory where the strict/loose eval result JSONL files are written.
_OUTPUT_DIR = flags.DEFINE_string(
    "output_dir",
    None,
    "Output directory for inference and eval results.",
    required=True,
)
45

46

47
@dataclasses.dataclass
class InputExample:
  """One evaluation example read from the input JSONL file."""

  # Unique identifier of the example.
  key: int
  # Instruction ids; keys into instructions_registry.INSTRUCTION_DICT.
  instruction_id_list: list[str]
  # The prompt text the model was asked to respond to.
  prompt: str
  # Per-instruction keyword arguments for build_description(); parallel to
  # instruction_id_list.
  kwargs: list[Dict[str, Optional[Union[str, int]]]]
53

54

55
@dataclasses.dataclass
class OutputExample:
  """Evaluation result for one prompt/response pair."""

  # Instruction ids that were checked.
  instruction_id_list: list[str]
  # The original prompt.
  prompt: str
  # The model response that was evaluated.
  response: str
  # True iff every instruction in instruction_id_list was followed.
  follow_all_instructions: bool
  # Per-instruction pass/fail flags; parallel to instruction_id_list.
  follow_instruction_list: list[bool]
62

63

64
def read_prompt_list(input_jsonl_filename):
  """Reads evaluation inputs from a JSONL file.

  Args:
    input_jsonl_filename: Path to a JSONL file where each line is a JSON
      object with "key", "instruction_id_list", "prompt" and "kwargs" fields.

  Returns:
    A list of InputExample, one per line of the input file.
  """
  inputs = []
  # Be explicit about the encoding instead of relying on the platform
  # default; the data files are UTF-8.
  with open(input_jsonl_filename, "r", encoding="utf-8") as f:
    for line in f:
      example = json.loads(line)
      inputs.append(
          InputExample(
              key=example["key"],
              instruction_id_list=example["instruction_id_list"],
              prompt=example["prompt"],
              kwargs=example["kwargs"],
          )
      )
  return inputs
76

77

78
def write_outputs(output_jsonl_filename, outputs):
  """Writes evaluation outputs to a JSONL file.

  Args:
    output_jsonl_filename: Destination path; one JSON object per line.
    outputs: Non-empty sequence of dataclass instances (e.g. OutputExample).
  """
  assert outputs
  with open(output_jsonl_filename, "w", encoding="utf-8") as f:
    for o in outputs:
      # dataclasses.asdict is simpler and more robust than reflecting over
      # dir(o) with __getattribute__, which would also serialize any public
      # method or class attribute the object happens to expose.
      f.write(json.dumps(dataclasses.asdict(o)))
      f.write("\n")
94

95

96
def test_instruction_following_strict(
    inp,
    prompt_to_response,
):
  """Checks, strictly, whether the response follows every instruction.

  Args:
    inp: An InputExample carrying the prompt, instruction ids and kwargs.
    prompt_to_response: Dict mapping prompt text to the model response.

  Returns:
    An OutputExample with per-instruction and overall pass/fail flags.
  """
  response = prompt_to_response[inp.prompt]
  is_following_list = []

  for idx, instruction_id in enumerate(inp.instruction_id_list):
    instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
    instruction = instruction_cls(instruction_id)

    instruction.build_description(**inp.kwargs[idx])
    args = instruction.get_instruction_args()
    if args and "prompt" in args:
      # Some instructions are parameterized by the prompt itself.
      instruction.build_description(prompt=inp.prompt)

    # An empty/whitespace-only response never counts as following.
    followed = bool(response.strip() and instruction.check_following(response))
    is_following_list.append(followed)

  return OutputExample(
      instruction_id_list=inp.instruction_id_list,
      prompt=inp.prompt,
      response=response,
      follow_all_instructions=all(is_following_list),
      follow_instruction_list=is_following_list,
  )
126

127

128
def test_instruction_following_loose(
    inp,
    prompt_to_response,
):
  """Checks instruction following leniently (an upper bound on accuracy).

  Besides the raw response, several relaxed variants are tried: the response
  with its first and/or last line removed, and each of those with all '*'
  characters stripped. An instruction counts as followed if any variant
  passes its check.

  Args:
    inp: An InputExample carrying the prompt, instruction ids and kwargs.
    prompt_to_response: Dict mapping prompt text to the model response.

  Returns:
    An OutputExample with per-instruction and overall pass/fail flags.
  """
  response = prompt_to_response[inp.prompt]
  lines = response.split("\n")
  without_first = "\n".join(lines[1:]).strip()
  without_last = "\n".join(lines[:-1]).strip()
  without_both = "\n".join(lines[1:-1]).strip()
  candidates = [
      response,
      response.replace("*", ""),
      without_first,
      without_last,
      without_both,
      without_first.replace("*", ""),
      without_last.replace("*", ""),
      without_both.replace("*", ""),
  ]

  is_following_list = []
  for idx, instruction_id in enumerate(inp.instruction_id_list):
    instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
    instruction = instruction_cls(instruction_id)

    instruction.build_description(**inp.kwargs[idx])
    args = instruction.get_instruction_args()
    if args and "prompt" in args:
      # Some instructions are parameterized by the prompt itself.
      instruction.build_description(prompt=inp.prompt)

    # any() short-circuits on the first passing variant, exactly like the
    # explicit break in a flag-based loop would.
    is_following_list.append(
        any(
            candidate.strip() and instruction.check_following(candidate)
            for candidate in candidates
        )
    )

  return OutputExample(
      instruction_id_list=inp.instruction_id_list,
      prompt=inp.prompt,
      response=response,
      follow_all_instructions=all(is_following_list),
      follow_instruction_list=is_following_list,
  )
179

180

181
def read_prompt_to_response_dict(input_jsonl_filename):
  """Creates a dictionary mapping each prompt to its model response.

  Args:
    input_jsonl_filename: Path to a JSONL file where each line is a JSON
      object with "prompt" and "response" fields.

  Returns:
    Dict mapping prompt text to response text. If a prompt appears more than
    once, the last response wins.
  """
  return_dict = {}
  # Be explicit about the encoding instead of relying on the platform
  # default; the data files are UTF-8.
  with open(input_jsonl_filename, "r", encoding="utf-8") as f:
    for line in f:
      example = json.loads(line)
      return_dict[example["prompt"]] = example["response"]
  return return_dict
189

190

191
def print_report(outputs):
  """Prints accuracy scores at the prompt, instruction and category level.

  Args:
    outputs: Iterable of OutputExample-like objects exposing
      `instruction_id_list` and `follow_instruction_list` attributes.
  """
  prompt_total = 0
  prompt_correct = 0
  instruction_total = 0
  instruction_correct = 0

  # tier0 aggregates by instruction category (the part before ":");
  # tier1 keeps the full instruction id.
  tier0_total = collections.defaultdict(int)
  tier0_correct = collections.defaultdict(int)
  tier1_total = collections.defaultdict(int)
  tier1_correct = collections.defaultdict(int)

  for example in outputs:
    followed = example.follow_instruction_list
    instruction_ids = example.instruction_id_list

    prompt_total += 1
    prompt_correct += all(followed)

    instruction_total += len(instruction_ids)
    instruction_correct += sum(followed)

    # One pass updates both aggregation tiers.
    for full_id, ok in zip(instruction_ids, followed):
      category = full_id.split(":")[0]
      tier0_total[category] += 1
      tier1_total[full_id] += 1
      if ok:
        tier0_correct[category] += 1
        tier1_correct[full_id] += 1

  print(f"prompt-level: {prompt_correct / prompt_total}")
  print(f"instruction-level: {instruction_correct / instruction_total}")
  print()
  for category in sorted(tier0_total):
    accuracy = tier0_correct[category] / tier0_total[category]
    print(f"{category} {accuracy}")
  print()
  for full_id in sorted(tier1_total):
    accuracy = tier1_correct[full_id] / tier1_total[full_id]
    print(f"{full_id} {accuracy}")
241

242

243
def main(argv):
  """Runs strict and loose instruction-following evaluation end to end."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  inputs = read_prompt_list(_INPUT_DATA.value)
  prompt_to_response = read_prompt_to_response_dict(
      _INPUT_RESPONSE_DATA.value)

  # Evaluate under both the strict and the loose criterion.
  eval_configs = (
      (test_instruction_following_strict, "eval_results_strict"),
      (test_instruction_following_loose, "eval_results_loose"),
  )
  for eval_fn, file_stem in eval_configs:
    logging.info("Generating %s...", file_stem)
    outputs = [eval_fn(inp, prompt_to_response) for inp in inputs]
    accuracy = sum(o.follow_all_instructions for o in outputs) / len(outputs)
    logging.info("Accuracy: %f", accuracy)

    output_path = os.path.join(_OUTPUT_DIR.value, file_stem + ".jsonl")
    write_outputs(output_path, outputs)
    logging.info("Generated: %s", output_path)

    # Prints instruction following accuracy report.
    print("=" * 64)
    print(f"{output_path} Accuracy Scores:")
    print_report(outputs)
274

275

276
# Parse absl flags and dispatch to main().
if __name__ == "__main__":
  app.run(main)
278

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.