google-research
277 lines · 8.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Binary of evaluating instruction following. See README.md."""
17
18import collections
19import dataclasses
20import json
21import os
22from typing import Dict, Optional, Sequence, Union
23
24from absl import app
25from absl import flags
26from absl import logging
27
28from instruction_following_eval import instructions_registry
29
30
# Required path to the jsonl file of evaluation prompts.
_INPUT_DATA = flags.DEFINE_string(
    "input_data", None, "path to input data", required=True
)

# Optional path to the jsonl file of model responses to evaluate.
_INPUT_RESPONSE_DATA = flags.DEFINE_string(
    "input_response_data", None, "path to input response data", required=False
)

# Required directory where the eval result jsonl files are written.
_OUTPUT_DIR = flags.DEFINE_string(
    "output_dir",
    None,
    "Output directory for inference and eval results.",
    required=True,
)
45
46
@dataclasses.dataclass
class InputExample:
  """A single evaluation example read from the input jsonl file."""

  key: int  # Unique id of the example.
  instruction_id_list: list[str]  # Instruction ids to check the response for.
  prompt: str  # The prompt the model was asked to answer.
  # One kwargs dict per instruction id, used to build that instruction's
  # description (values may be absent/None for instructions without args).
  # Builtin `dict` generic used for consistency with `list[...]` above.
  kwargs: list[dict[str, Optional[Union[str, int]]]]
53
54
@dataclasses.dataclass
class OutputExample:
  """Evaluation result for one prompt/response pair."""

  instruction_id_list: list[str]  # Instruction ids that were checked.
  prompt: str  # Prompt the response answers.
  response: str  # Model response that was evaluated.
  follow_all_instructions: bool  # True iff every instruction was followed.
  follow_instruction_list: list[bool]  # Per-instruction follow results.
63
def read_prompt_list(input_jsonl_filename):
  """Reads evaluation inputs from a jsonl file.

  Args:
    input_jsonl_filename: Path to a jsonl file where each line is a JSON
      object with "key", "instruction_id_list", "prompt" and "kwargs" fields.

  Returns:
    A list of InputExample, one per line of the file.
  """
  inputs = []
  # Explicit encoding: the platform default may not be utf-8.
  with open(input_jsonl_filename, "r", encoding="utf-8") as f:
    for line in f:
      example = json.loads(line)
      inputs.append(
          InputExample(
              key=example["key"],
              instruction_id_list=example["instruction_id_list"],
              prompt=example["prompt"],
              kwargs=example["kwargs"],
          )
      )
  return inputs
76
77
def write_outputs(output_jsonl_filename, outputs):
  """Writes outputs to a jsonl file, one JSON object per line.

  Every non-underscore attribute of each output object is serialized;
  dir() yields attributes in alphabetical order, so key order in the
  written JSON is alphabetical (unchanged from the original behavior).

  Args:
    output_jsonl_filename: Destination path.
    outputs: Non-empty sequence of objects (e.g. OutputExample).

  Raises:
    AssertionError: If outputs is empty.
  """
  assert outputs
  with open(output_jsonl_filename, "w", encoding="utf-8") as f:
    for o in outputs:
      # getattr() is the idiomatic form of o.__getattribute__(); the
      # redundant inner list comprehension is flattened into one filter.
      record = {
          attr_name: getattr(o, attr_name)
          for attr_name in dir(o)
          if not attr_name.startswith("_")
      }
      f.write(json.dumps(record))
      f.write("\n")
95
def test_instruction_following_strict(
    inp,
    prompt_to_response,
):
  """Checks whether a response strictly follows every instruction.

  Args:
    inp: An InputExample describing the prompt and its instructions.
    prompt_to_response: Dict mapping each prompt to the model response.

  Returns:
    An OutputExample with per-instruction and overall follow results.
  """
  response = prompt_to_response[inp.prompt]
  is_following_list = []

  for index, instruction_id in enumerate(inp.instruction_id_list):
    instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
    instruction = instruction_cls(instruction_id)

    instruction.build_description(**inp.kwargs[index])
    args = instruction.get_instruction_args()
    if args and "prompt" in args:
      # Some instructions need the original prompt to build their description.
      instruction.build_description(prompt=inp.prompt)

    # An empty (whitespace-only) response never counts as following.
    followed = bool(response.strip() and instruction.check_following(response))
    is_following_list.append(followed)

  return OutputExample(
      instruction_id_list=inp.instruction_id_list,
      prompt=inp.prompt,
      response=response,
      follow_all_instructions=all(is_following_list),
      follow_instruction_list=is_following_list,
  )
126
127
def test_instruction_following_loose(
    inp,
    prompt_to_response,
):
  """Checks instruction following under loosened response variants.

  Besides the raw response, variants with the first/last lines removed
  and/or all "*" characters stripped are also tried; an instruction counts
  as followed if any variant passes. This gives an upper bound on the
  strict score.

  Args:
    inp: An InputExample describing the prompt and its instructions.
    prompt_to_response: Dict mapping each prompt to the model response.

  Returns:
    An OutputExample with per-instruction and overall follow results.
  """
  response = prompt_to_response[inp.prompt]
  lines = response.split("\n")
  no_first = "\n".join(lines[1:]).strip()
  no_last = "\n".join(lines[:-1]).strip()
  no_both = "\n".join(lines[1:-1]).strip()
  candidates = [
      response,
      response.replace("*", ""),
      no_first,
      no_last,
      no_both,
      no_first.replace("*", ""),
      no_last.replace("*", ""),
      no_both.replace("*", ""),
  ]
  is_following_list = []

  for index, instruction_id in enumerate(inp.instruction_id_list):
    instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
    instruction = instruction_cls(instruction_id)

    instruction.build_description(**inp.kwargs[index])
    args = instruction.get_instruction_args()
    if args and "prompt" in args:
      # Some instructions need the original prompt to build their description.
      instruction.build_description(prompt=inp.prompt)

    # Followed if any non-empty variant passes the check.
    is_following_list.append(
        any(
            candidate.strip() and instruction.check_following(candidate)
            for candidate in candidates
        )
    )

  return OutputExample(
      instruction_id_list=inp.instruction_id_list,
      prompt=inp.prompt,
      response=response,
      follow_all_instructions=all(is_following_list),
      follow_instruction_list=is_following_list,
  )
179
180
def read_prompt_to_response_dict(input_jsonl_filename):
  """Creates a dict mapping each prompt to its model response.

  Args:
    input_jsonl_filename: Path to a jsonl file where each line is a JSON
      object with "prompt" and "response" fields.

  Returns:
    Dict from prompt string to response string. If a prompt is repeated,
    the last response seen wins.
  """
  return_dict = {}
  # Explicit encoding: the platform default may not be utf-8.
  with open(input_jsonl_filename, "r", encoding="utf-8") as f:
    for line in f:
      example = json.loads(line)
      return_dict[example["prompt"]] = example["response"]
  return return_dict
189
190
def print_report(outputs):
  """Prints accuracy scores at prompt, instruction, and category level.

  Args:
    outputs: Sequence of OutputExample-like objects with
      `instruction_id_list` and `follow_instruction_list` attributes.
  """
  prompt_total = 0
  prompt_correct = 0
  instruction_total = 0
  instruction_correct = 0

  # tier0 groups by instruction category (prefix before ":");
  # tier1 keys on the full instruction id.
  tier0_total = collections.defaultdict(int)
  tier0_correct = collections.defaultdict(int)
  tier1_total = collections.defaultdict(int)
  tier1_correct = collections.defaultdict(int)

  for example in outputs:
    follow_instruction_list = example.follow_instruction_list
    instruction_id_list = example.instruction_id_list

    prompt_total += 1
    if all(follow_instruction_list):
      prompt_correct += 1

    instruction_total += len(instruction_id_list)
    instruction_correct += sum(follow_instruction_list)

    # Single pass over the pairs updates both tiers (the original ran two
    # identical zip loops over the same lists).
    for instruction_id, followed_or_not in zip(
        instruction_id_list, follow_instruction_list
    ):
      tier1_total[instruction_id] += 1
      if followed_or_not:
        tier1_correct[instruction_id] += 1

      category = instruction_id.split(":")[0]
      tier0_total[category] += 1
      if followed_or_not:
        tier0_correct[category] += 1

  print(f"prompt-level: {prompt_correct / prompt_total}")
  print(f"instruction-level: {instruction_correct / instruction_total}")
  print()
  for instruction_id in sorted(tier0_total.keys()):
    accuracy = tier0_correct[instruction_id] / tier0_total[instruction_id]
    print(f"{instruction_id} {accuracy}")
  print()
  for instruction_id in sorted(tier1_total.keys()):
    accuracy = tier1_correct[instruction_id] / tier1_total[instruction_id]
    print(f"{instruction_id} {accuracy}")
241
242
def main(argv):
  """Runs strict and loose instruction-following evaluation end to end."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  inputs = read_prompt_list(_INPUT_DATA.value)
  prompt_to_response = read_prompt_to_response_dict(
      _INPUT_RESPONSE_DATA.value)

  # Evaluate under both the strict and the loose criterion.
  eval_configs = (
      (test_instruction_following_strict, "eval_results_strict"),
      (test_instruction_following_loose, "eval_results_loose"),
  )
  for func, output_file_name in eval_configs:
    logging.info("Generating %s...", output_file_name)
    outputs = [func(inp, prompt_to_response) for inp in inputs]
    accuracy = sum(o.follow_all_instructions for o in outputs) / len(outputs)
    logging.info("Accuracy: %f", accuracy)

    output_file_name = os.path.join(
        _OUTPUT_DIR.value, output_file_name + ".jsonl"
    )
    write_outputs(output_file_name, outputs)
    logging.info("Generated: %s", output_file_name)

    # Prints instruction following accuracy report.
    print("=" * 64)
    print(f"{output_file_name} Accuracy Scores:")
    print_report(outputs)
274
275
# Script entry point: absl parses the flags, then invokes main().
if __name__ == "__main__":
  app.run(main)
278