import importlib
import os
import tempfile
from unittest import TestCase

import pytest
import torch
from parameterized import parameterized
from torch.testing import assert_close

from peft.mapping import get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from peft.utils.other import WEIGHTS_NAME, prepare_model_for_int8_training
from peft.utils.save_and_load import get_peft_model_state_dict

from tests.testing_common import PeftCommonTester


def is_llama_available() -> bool:
    """Check if Llama is available in the transformers library (it's not in earlier versions)."""
    try:
        return importlib.util.find_spec("transformers.models.llama.modeling_llama") is not None
    except ModuleNotFoundError:
        return False


if is_llama_available():
    # Only import the Llama classes when they exist, so this module still imports under
    # older transformers releases that do not ship Llama.
    from transformers import LlamaConfig, LlamaForCausalLM


class MultiTaskPromptTuningTester(TestCase, PeftCommonTester):
    """
    Tests for the MultiTaskPromptTuning model.

    Some of these tests were adapted from `test_peft_model.py` (which has been refactored since), but since we haven't
    checked in the test checkpoints for Llama into `hf-internal-testing`, we separate them for now.
    """

    def setUp(self):
        """Check that llama is available in transformers package before running each test."""
        if not is_llama_available():
            self.skipTest("Llama not available in transformers. Skipping test.")
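
    # Helper methods: build a tiny Llama config and a multitask prompt tuning config so each
    # test can construct a fresh PEFT model cheaply.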
    @staticmethod
    def _create_test_llama_config():
        """Create a test config for a small Llama model for testing."""
        # The exact sizes are not critical; they only need to yield a very small model.
        return LlamaConfig(
            vocab_size=16,
            hidden_size=8,
            intermediate_size=8,
            num_hidden_layers=8,
            num_attention_heads=4,
            use_cache=False,
        )

    @classmethod
    def _create_multitask_prompt_tuning_config(cls) -> MultitaskPromptTuningConfig:
        return MultitaskPromptTuningConfig(
            task_type="CAUSAL_LM",
            num_virtual_tokens=50,
            num_tasks=3,  # tests below use task_ids up to 2
            prompt_tuning_init_text=(
                "classify the following into either positive or negative, or entailment, neutral or contradiction:"
            ),
        )
    def test_prepare_for_training(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        dummy_output = model.get_input_embeddings()(dummy_input)

        assert not dummy_output.requires_grad

    def test_prepare_for_int8_training(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = prepare_model_for_int8_training(model)
        model = model.to(self.torch_device)

        for param in model.parameters():
            assert not param.requires_grad

        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())

        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        dummy_output = model.get_input_embeddings()(dummy_input)

        assert dummy_output.requires_grad
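
    # Saving and reloading should round-trip the adapter weights exactly and write only the
    # adapter files (weights and config), never the base model checkpoint.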
    def test_save_pretrained(self) -> None:
        seed = 420  # arbitrary fixed seed for deterministic weight init
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # the reloaded adapter state dict must match the saved one, key for key
            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            assert state_dict.keys() == state_dict_from_pretrained.keys()
            assert len(state_dict) == 3

            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # only the adapter weights and adapter config are saved, not the base model
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))
            assert not os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin"))
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
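
    # Same round-trip check, but saving with safe_serialization=False to cover the legacy
    # `adapter_model.bin` format.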
    def test_save_pretrained_regression(self) -> None:
        seed = 420
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname, safe_serialization=False)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            assert state_dict.keys() == state_dict_from_pretrained.keys()
            assert len(state_dict) == 3

            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # with safe_serialization=False the adapter is written as `adapter_model.bin`
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin"))
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))
            assert not os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin"))
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))

    def test_generate(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
        task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

        # check that `generate` works with keyword arguments
        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

        # check that `generate` also works when input_ids is passed positionally
        _ = model.generate(input_ids, attention_mask=attention_mask, task_ids=task_ids)
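
    # Generation with the KV cache enabled must produce exactly the same tokens as without it.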
    def test_use_cache(self) -> None:
        """Test that MultiTaskPromptTuning works when the Llama config sets use_cache=True."""
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

        original = LlamaForCausalLM(self._create_test_llama_config()).eval()
        mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config())
        mpt = mpt.to(self.torch_device)

        expected = mpt.generate(input_ids=input_ids, max_length=8, task_ids=task_ids)

        mpt.base_model.config.use_cache = True
        actual = mpt.generate(input_ids=input_ids, max_length=8, task_ids=task_ids)
        assert_close(expected, actual, rtol=0, atol=0)
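
    # Inference should also work when the base model is loaded in half precision (bfloat16).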
    def test_bf16_inference(self) -> None:
        """Test that MultiTaskPromptTuning works with a bfloat16 (half-precision) Llama model."""
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        task_ids = torch.tensor([1, 2]).to(self.torch_device)

        original = LlamaForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16
        )
        mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config())
        mpt = mpt.to(self.torch_device)
        _ = mpt.generate(input_ids=input_ids, task_ids=task_ids)
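
    # RANDOM init needs no source prompt weights; `generate` must still be given task_ids.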
    def test_generate_text_with_random_init(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())

        config = self._create_multitask_prompt_tuning_config()
        config.prompt_tuning_init = MultitaskPromptTuningInit.RANDOM

        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
        task_ids = torch.LongTensor([0]).to(self.torch_device)

        # check that `generate` works
        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

        with pytest.raises(ValueError):
            # check that `generate` raises an error if task_ids are not passed
            _ = model.generate(input_ids, attention_mask=attention_mask)
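
    # The remaining init strategies seed the new prompt from a previously saved adapter,
    # loaded via `prompt_tuning_init_state_dict_path`.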
    @parameterized.expand(
        [
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
        ]
    )
    def test_generate_text_with_other_init(self, prompt_tuning_init) -> None:
        with tempfile.TemporaryDirectory() as tmp_dirname:
            # save a source adapter whose weights will seed the new adapter's initialization;
            # the legacy .bin format matches WEIGHTS_NAME used for the init path below
            model = LlamaForCausalLM(self._create_test_llama_config())
            model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
            model.save_pretrained(tmp_dirname, safe_serialization=False)

            config = MultitaskPromptTuningConfig(
                task_type="CAUSAL_LM",
                num_virtual_tokens=50,
                num_tasks=3,
                prompt_tuning_init_text=(
                    "classify the following into either positive or negative, or entailment, neutral or contradiction:"
                ),
                prompt_tuning_init=prompt_tuning_init,
                prompt_tuning_init_state_dict_path=os.path.join(tmp_dirname, WEIGHTS_NAME),
            )
            model = LlamaForCausalLM(self._create_test_llama_config())
            model = get_peft_model(model, config)
            model = model.to(self.torch_device)

            input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
            attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
            task_ids = torch.LongTensor([0]).to(self.torch_device)

            # check that `generate` works
            _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

            with pytest.raises(ValueError):
                # check that `generate` raises an error if task_ids are not passed
                _ = model.generate(input_ids, attention_mask=attention_mask)