import importlib
import os
import tempfile
import unittest
from unittest import TestCase

import pytest
import torch
from torch.testing import assert_close

from peft.mapping import get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.adaption_prompt import AdaptionPromptConfig
from peft.utils.other import prepare_model_for_int8_training
from peft.utils.save_and_load import get_peft_model_state_dict
from tests.testing_common import PeftCommonTester


def is_llama_available() -> bool:
    """Check if Llama is available in the transformers library (it's not in earlier versions)."""
    try:
        return importlib.util.find_spec("transformers.models.llama.modeling_llama") is not None
    except ModuleNotFoundError:
        return False


if is_llama_available():
    # We guard the import so that the tests still collect in environments whose
    # transformers version does not ship Llama.
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel


class AdaptionPromptTester(TestCase, PeftCommonTester):
    """
    Tests for the AdaptionPrompt model.

    Some of these tests were adapted from `test_peft_model.py` (which has been refactored since), but since we haven't
    checked in the test checkpoints for Llama into `hf-internal-testing`, we separate them for now.
    """

    def setUp(self):
        # Check that Llama is available in the transformers package before each test.
        if not is_llama_available():
            self.skipTest("Llama not available in transformers. Skipping test.")

    @staticmethod
    def _create_test_llama_config():
        """Create a test config for a small Llama model for testing."""
        # Tiny dimensions keep the tests fast. Only num_attention_heads=4 survives
        # from the original; the other values are representative small settings.
        return LlamaConfig(
            vocab_size=16,
            hidden_size=8,
            intermediate_size=8,
            num_hidden_layers=8,
            num_attention_heads=4,
            use_cache=False,
        )

    def test_attributes(self) -> None:
        model = LlamaModel(self._create_test_llama_config())
        config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4)
        model = get_peft_model(model, config)

        assert hasattr(model, "save_pretrained")
        assert hasattr(model, "from_pretrained")
        assert hasattr(model, "push_to_hub")

    def test_prepare_for_training(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM")
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        dummy_output = model.get_input_embeddings()(dummy_input)

        assert not dummy_output.requires_grad
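
    # The int8-training test below checks two things: `prepare_model_for_int8_training`
    # freezes every base-model parameter, and gradients can still reach an adapter
    # because the input embeddings are made to emit grad-requiring outputs.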
    def test_prepare_for_int8_training(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = prepare_model_for_int8_training(model)
        model = model.to(self.torch_device)

        for param in model.parameters():
            assert not param.requires_grad

        config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM")
        model = get_peft_model(model, config)

        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        dummy_output = model.get_input_embeddings()(dummy_input)

        assert dummy_output.requires_grad
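
    # An AdaptionPrompt adapter stores two tensors per adapted layer (the learned
    # prompt and its gate), so with `adapter_layers=2` the save/load tests below
    # expect exactly 4 entries in the adapter state dict.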
    def test_save_pretrained_regression(self) -> None:
        seed = 420  # fixed seed so both models below get identical random init
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname, safe_serialization=False)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            assert state_dict.keys() == state_dict_from_pretrained.keys()

            # check that the number of saved parameters is 4 -- 2 layers of (prompt tokens and gate)
            assert len(state_dict) == 4

            # check if tensors are equal
            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # check if `adapter_model.bin` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin"))

            # check if `adapter_config.json` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))

            # check if `model.safetensors` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))

            # check if `config.json` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))
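
    # Same round-trip as the regression test above, but via the default safetensors
    # serialization, so `adapter_model.safetensors` is expected on disk instead of
    # `adapter_model.bin`.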
    def test_save_pretrained(self) -> None:
        seed = 420
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            assert state_dict.keys() == state_dict_from_pretrained.keys()

            # check that the number of saved parameters is 4 -- 2 layers of (prompt tokens and gate)
            assert len(state_dict) == 4

            # check if tensors are equal
            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # check if `adapter_model.safetensors` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))

            # check if `adapter_config.json` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))

            # check if `model.safetensors` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))

            # check if `config.json` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))

    def test_save_pretrained_selected_adapters(self) -> None:
        seed = 420
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        new_adapter_config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        model.add_adapter("new_adapter", new_adapter_config)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            model_from_pretrained.load_adapter(tmp_dirname, "new_adapter")

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)
            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            assert state_dict.keys() == state_dict_from_pretrained.keys()

            # check that the number of saved parameters is 4 -- 2 layers of (prompt tokens and gate)
            assert len(state_dict) == 4

            # check if tensors are equal
            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # check if `adapter_model.safetensors` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))

            # check if `adapter_config.json` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))

            # check if `model.safetensors` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "model.safetensors"))

            # check if `config.json` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))

    def test_generate(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        config = AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)

        # check if `generate` works
        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask)

        # check if `generate` works if positional arguments are passed
        _ = model.generate(input_ids, attention_mask=attention_mask)
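
    # The tests below rely on AdaptionPrompt's zero-init: a freshly added adapter
    # gates its contribution to zero, so adapted logits match the base model exactly
    # (rtol=0, atol=0) until a training step updates the prompt and gate.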
    def test_sequence_adapter_ops(self) -> None:
        """Test sequence of adapter operations."""
        # Test input data.
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        target_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)

        # Create original llama model.
        original = LlamaForCausalLM(self._create_test_llama_config())
        original = original.to(self.torch_device)
        original_before = original(input_ids=input_ids, attention_mask=attention_mask)

        # Get AdaptionPrompt model.
        adapted = get_peft_model(
            original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        )
        adapted = adapted.to(self.torch_device)
        default_before = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)

        # Test zero-init: the logits should be exactly the same.
        assert_close(original_before.logits, default_before.logits, rtol=0, atol=0)

        # Single fine-tuning step on "default" adapter.
        optimizer = torch.optim.SGD(adapted.parameters(), lr=1)
        optimizer.zero_grad()
        default_before.loss.backward()
        optimizer.step()

        # Test that the output changed.
        default_after = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
        assert not torch.allclose(default_before.logits, default_after.logits)

        with adapted.disable_adapter():
            # Test that the output is the same as the original output.
            default_disabled = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
            assert_close(original_before.logits, default_disabled.logits, rtol=0, atol=0)

        # Add new adapter 1.
        adapted.add_adapter("adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM"))
        # Test zero-init.
        adapter_1_before = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
        assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0)

        # Single fine-tuning step on adapter 1.
        optimizer = torch.optim.SGD(adapted.parameters(), lr=1)
        optimizer.zero_grad()
        adapter_1_before.loss.backward()
        optimizer.step()

        # Test that adapter 1 output changed.
        adapter_1_after = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
        assert not torch.allclose(adapter_1_before.logits, adapter_1_after.logits)
        assert not torch.allclose(original_before.logits, adapter_1_after.logits)
        assert not torch.allclose(default_after.logits, adapter_1_after.logits)

        with adapted.disable_adapter():
            # Test that the output is the same as the original output.
            adapter_1_disabled = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
            assert_close(original_before.logits, adapter_1_disabled.logits, rtol=0, atol=0)

        # Set adapter back to default.
        adapted.set_adapter("default")

        # Test that the output is the same as the default output after training.
        default_after_set = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
        assert_close(default_after.logits, default_after_set.logits, rtol=0, atol=0)
        assert not torch.allclose(original_before.logits, default_after_set.logits)
        assert not torch.allclose(adapter_1_after.logits, default_after_set.logits)

    def test_add_and_set_while_disabled(self):
        """Test that adding and setting adapters while disabled works as intended."""
        # Test input data.
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        target_ids = torch.LongTensor([[0, 0, 0], [0, 0, 0]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)

        # Create original llama model.
        original = LlamaForCausalLM(self._create_test_llama_config())
        original = original.to(self.torch_device)
        original_before = original(input_ids=input_ids, attention_mask=attention_mask)

        # Get AdaptionPrompt model.
        adapted = get_peft_model(
            original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        )
        adapted = adapted.to(self.torch_device)

        with adapted.disable_adapter():
            adapted.add_adapter(
                "adapter 1", AdaptionPromptConfig(adapter_layers=3, adapter_len=8, task_type="CAUSAL_LM")
            )

        # Test that the output is the same as the original output.
        adapter_1_before = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
        assert_close(original_before.logits, adapter_1_before.logits, rtol=0, atol=0)

        # Single fine-tuning step on adapter 1.
        optimizer = torch.optim.SGD(adapted.parameters(), lr=1)
        optimizer.zero_grad()
        adapter_1_before.loss.backward()
        optimizer.step()

        # Test that adapter 1 output changed.
        adapter_1_after = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
        assert not torch.allclose(original_before.logits, adapter_1_after.logits)

        adapted.set_adapter("default")
        with adapted.disable_adapter():
            adapted.set_adapter("adapter 1")

        # Test that adapter 1 is active again after being set while disabled.
        adapter_1_after_set = adapted(input_ids=input_ids, attention_mask=attention_mask, labels=target_ids)
        assert_close(adapter_1_after.logits, adapter_1_after_set.logits, rtol=0, atol=0)

    def test_use_cache(self) -> None:
        """Test that AdaptionPrompt works when Llama config use_cache=True."""
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        # Build the model with use_cache=False first. Only num_attention_heads=4
        # survives from the original inline config; the other small values are
        # representative.
        original = LlamaForCausalLM(
            LlamaConfig(
                vocab_size=16,
                hidden_size=8,
                intermediate_size=8,
                num_hidden_layers=8,
                num_attention_heads=4,
                use_cache=False,
            )
        )
        adapted = get_peft_model(
            original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        )
        adapted = adapted.to(self.torch_device)
        expected = adapted.generate(input_ids=input_ids, max_length=8)

        # Check that turning on the KV cache does not change the generated output.
        adapted.base_model.config.use_cache = True
        actual = adapted.generate(input_ids=input_ids, max_length=8)
        assert_close(expected, actual, rtol=0, atol=0)

    def test_bf16_inference(self) -> None:
        """Test that AdaptionPrompt works with a bf16 half-precision Llama model."""
        if self.torch_device == "mps":
            return pytest.skip("Skipping bf16 test on MPS")

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        original = LlamaForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16
        )
        adapted = get_peft_model(
            original, AdaptionPromptConfig(adapter_layers=2, adapter_len=4, task_type="CAUSAL_LM")
        )
        adapted = adapted.to(self.torch_device)
        _ = adapted.generate(input_ids=input_ids)

    @unittest.expectedFailure
    def test_disable_adapter(self):
        llama_config = self._create_test_llama_config()
        model = LlamaForCausalLM(llama_config).to(self.torch_device)
        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        output_before = model(dummy_input).logits

        config = AdaptionPromptConfig(adapter_layers=1, adapter_len=4, task_type="CAUSAL_LM")
        model = get_peft_model(model, config).to(self.torch_device)
        output_peft = model(dummy_input).logits
        # Expected to fail: because of zero-init, an untrained adapter produces
        # exactly the base model's logits, so this assertion cannot pass.
        assert not torch.allclose(output_before, output_peft)

        with model.disable_adapter():
            output_peft_disabled = model(dummy_input).logits
        assert torch.allclose(output_before, output_peft_disabled)