# alignment-handbook
# 196 lines · 8.4 KB
1# coding=utf-8
2# Copyright 2023 The HuggingFace Team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15import unittest
16from copy import deepcopy
17
18import pytest
19from datasets import Dataset
20from transformers import AutoTokenizer
21
22from alignment import DataArguments, ModelArguments, apply_chat_template, get_datasets, get_tokenizer
23from alignment.data import maybe_insert_system_message
24
25
class GetDatasetsTest(unittest.TestCase):
    """All test datasets referenced below hold 100 examples per split."""

    def test_loading_data_args(self):
        # Fractions passed via a DataArguments wrapper: the train split is
        # subsampled per-dataset, while the test split is always kept whole.
        mixer = {
            "HuggingFaceH4/testing_alpaca_small": 0.5,
            "HuggingFaceH4/testing_self_instruct_small": 0.3,
            "HuggingFaceH4/testing_codealpaca_small": 0.2,
        }
        loaded = get_datasets(DataArguments(dataset_mixer=mixer))
        self.assertEqual(len(loaded["train"]), 100)
        self.assertEqual(len(loaded["test"]), 300)

    def test_loading_data_dict(self):
        # get_datasets also accepts the raw mixer dict, without DataArguments.
        mixer = {
            "HuggingFaceH4/testing_alpaca_small": 0.5,
            "HuggingFaceH4/testing_self_instruct_small": 0.3,
            "HuggingFaceH4/testing_codealpaca_small": 0.2,
        }
        loaded = get_datasets(mixer)
        self.assertEqual(len(loaded["train"]), 100)
        self.assertEqual(len(loaded["test"]), 300)

    def test_loading_with_unit_fractions(self):
        # A fraction of 1.0 keeps every training example of each dataset.
        mixer = dict.fromkeys(
            [
                "HuggingFaceH4/testing_alpaca_small",
                "HuggingFaceH4/testing_self_instruct_small",
                "HuggingFaceH4/testing_codealpaca_small",
            ],
            1.0,
        )
        loaded = get_datasets(mixer)
        self.assertEqual(len(loaded["train"]), 300)
        self.assertEqual(len(loaded["test"]), 300)

    def test_loading_with_fractions_greater_than_unity(self):
        # Fractions may sum past 1.0 — each dataset is sampled independently.
        loaded = get_datasets(
            {
                "HuggingFaceH4/testing_alpaca_small": 0.7,
                "HuggingFaceH4/testing_self_instruct_small": 0.4,
            }
        )
        self.assertEqual(len(loaded["train"]), 70 + 40)
        self.assertEqual(len(loaded["test"]), 200)

    def test_loading_fails_with_negative_fractions(self):
        # Negative fractions must be rejected outright.
        with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."):
            get_datasets(
                {
                    "HuggingFaceH4/testing_alpaca_small": 0.7,
                    "HuggingFaceH4/testing_self_instruct_small": -0.3,
                }
            )

    def test_loading_single_split_with_unit_fractions(self):
        # Requesting only "test" must not create a "train" entry at all.
        loaded = get_datasets({"HuggingFaceH4/testing_alpaca_small": 1.0}, splits=["test"])
        self.assertEqual(len(loaded["test"]), 100)
        with self.assertRaises(KeyError):
            loaded["train"]
84
85
class ApplyChatTemplateTest(unittest.TestCase):
    def setUp(self):
        # A real chat-capable tokenizer so the Zephyr template is applied.
        self.tokenizer = get_tokenizer(
            ModelArguments(model_name_or_path="HuggingFaceH4/zephyr-7b-alpha"),
            DataArguments(),
        )
        # All three conversation columns share the same opening turns and
        # differ only in the final assistant reply.
        opening = [
            {"role": "system", "content": "You are a happy chatbot"},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Bonjour!"},
            {"role": "user", "content": "How are you?"},
        ]
        good_reply = {"role": "assistant", "content": "I am doing well, thanks!"}
        bad_reply = {"role": "assistant", "content": "Not so good tbh"}
        self.dataset = Dataset.from_dict(
            {
                "prompt": ["Hello!"],
                "messages": [opening + [good_reply]],
                "chosen": [opening + [good_reply]],
                "rejected": [opening + [bad_reply]],
            }
        )

    def test_maybe_insert_system_message(self):
        # Mistral's chat template does not accept a system turn, whereas
        # CodeLlama's does (chosen because it needs no HF access token).
        no_sys_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        sys_tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
        without_system = [{"role": "user", "content": "Tell me a joke."}]
        with_system = [{"role": "system", "content": ""}, {"role": "user", "content": "Tell me a joke."}]

        for_mistral = deepcopy(without_system)
        for_llama = deepcopy(without_system)
        maybe_insert_system_message(for_mistral, no_sys_tokenizer)
        maybe_insert_system_message(for_llama, sys_tokenizer)

        # Mistral's list must be untouched; llama's gains an empty system turn.
        self.assertEqual(for_mistral, without_system)
        self.assertEqual(for_llama, with_system)

    def test_sft(self):
        # SFT flattens the full "messages" conversation into a single string.
        mapped = self.dataset.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": self.tokenizer, "task": "sft"},
            remove_columns=self.dataset.column_names,
        )
        expected = {
            "text": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\nI am doing well, thanks!</s>\n"
        }
        self.assertDictEqual(mapped[0], expected)

    def test_generation(self):
        # Drop the last assistant turn so the template ends with an open
        # assistant prompt ready for generation.
        trimmed = self.dataset.map(lambda row: {"messages": row["messages"][:-1]})
        mapped = trimmed.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": self.tokenizer, "task": "generation"},
            remove_columns=self.dataset.column_names,
        )
        expected = {
            "text": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\n"
        }
        self.assertDictEqual(mapped[0], expected)

    def test_rm(self):
        # Reward modelling renders "chosen" and "rejected" as full texts.
        mapped = self.dataset.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": self.tokenizer, "task": "rm"},
            remove_columns=self.dataset.column_names,
        )
        expected = {
            "text_chosen": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\nI am doing well, thanks!</s>\n",
            "text_rejected": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n<|assistant|>\nNot so good tbh</s>\n",
        }
        self.assertDictEqual(mapped[0], expected)

    def test_dpo(self):
        # DPO splits the shared prompt from the two candidate completions.
        mapped = self.dataset.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": self.tokenizer, "task": "dpo"},
            remove_columns=self.dataset.column_names,
        )
        expected = {
            "text_prompt": "<|system|>\nYou are a happy chatbot</s>\n<|user|>\nHello!</s>\n<|assistant|>\nBonjour!</s>\n<|user|>\nHow are you?</s>\n",
            "text_chosen": "<|assistant|>\nI am doing well, thanks!</s>\n",
            "text_rejected": "<|assistant|>\nNot so good tbh</s>\n",
        }
        self.assertDictEqual(mapped[0], expected)
197