# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch

from peft import (
    AutoPeftModel,
    AutoPeftModelForCausalLM,
    AutoPeftModelForFeatureExtraction,
    AutoPeftModelForQuestionAnswering,
    AutoPeftModelForSeq2SeqLM,
    AutoPeftModelForSequenceClassification,
    AutoPeftModelForTokenClassification,
    PeftModel,
    PeftModelForCausalLM,
    PeftModelForFeatureExtraction,
    PeftModelForQuestionAnswering,
    PeftModelForSeq2SeqLM,
    PeftModelForSequenceClassification,
    PeftModelForTokenClassification,
)
from peft.utils import infer_device
class PeftAutoModelTester(unittest.TestCase):
    """Smoke tests for the ``AutoPeftModel*`` classes.

    Each test checks that the right ``PeftModel*`` subclass is instantiated
    from a hub checkpoint, that a save/load round-trip preserves the class,
    and that extra kwargs (``torch_dtype``) and positional args
    (``adapter_name``, ``is_trainable``) are forwarded to ``from_pretrained``.
    """

    # bfloat16 is not supported on MPS, fall back to float16 there
    dtype = torch.float16 if infer_device() == "mps" else torch.bfloat16

    def test_peft_causal_lm(self):
        model_id = "peft-internal-testing/tiny-OPTForCausalLM-lora"
        model = AutoPeftModelForCausalLM.from_pretrained(model_id)
        assert isinstance(model, PeftModelForCausalLM)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            model = AutoPeftModelForCausalLM.from_pretrained(tmp_dirname)
            assert isinstance(model, PeftModelForCausalLM)

        # check if kwargs are passed correctly
        model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, PeftModelForCausalLM)
        assert model.base_model.lm_head.weight.dtype == self.dtype

        adapter_name = "default"
        is_trainable = False
        # check that positional args are passed through correctly
        _ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

    def test_peft_causal_lm_extended_vocab(self):
        model_id = "peft-internal-testing/tiny-random-OPTForCausalLM-extended-vocab"
        model = AutoPeftModelForCausalLM.from_pretrained(model_id)
        assert isinstance(model, PeftModelForCausalLM)

        # check if kwargs are passed correctly
        model = AutoPeftModelForCausalLM.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, PeftModelForCausalLM)
        assert model.base_model.lm_head.weight.dtype == self.dtype

        adapter_name = "default"
        is_trainable = False
        # check that positional args are passed through correctly
        _ = AutoPeftModelForCausalLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

    def test_peft_seq2seq_lm(self):
        model_id = "peft-internal-testing/tiny_T5ForSeq2SeqLM-lora"
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id)
        assert isinstance(model, PeftModelForSeq2SeqLM)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            model = AutoPeftModelForSeq2SeqLM.from_pretrained(tmp_dirname)
            assert isinstance(model, PeftModelForSeq2SeqLM)

        # check if kwargs are passed correctly
        model = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, PeftModelForSeq2SeqLM)
        assert model.base_model.lm_head.weight.dtype == self.dtype

        adapter_name = "default"
        is_trainable = False
        # check that positional args are passed through correctly
        _ = AutoPeftModelForSeq2SeqLM.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)

    def test_peft_sequence_cls(self):
        model_id = "peft-internal-testing/tiny_OPTForSequenceClassification-lora"
        model = AutoPeftModelForSequenceClassification.from_pretrained(model_id)
        assert isinstance(model, PeftModelForSequenceClassification)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            model = AutoPeftModelForSequenceClassification.from_pretrained(tmp_dirname)
            assert isinstance(model, PeftModelForSequenceClassification)

        # check if kwargs are passed correctly
        model = AutoPeftModelForSequenceClassification.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, PeftModelForSequenceClassification)
        assert model.score.original_module.weight.dtype == self.dtype

        adapter_name = "default"
        is_trainable = False
        # check that positional args are passed through correctly
        _ = AutoPeftModelForSequenceClassification.from_pretrained(
            model_id, adapter_name, is_trainable, torch_dtype=self.dtype
        )

    def test_peft_token_classification(self):
        model_id = "peft-internal-testing/tiny_GPT2ForTokenClassification-lora"
        model = AutoPeftModelForTokenClassification.from_pretrained(model_id)
        assert isinstance(model, PeftModelForTokenClassification)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            model = AutoPeftModelForTokenClassification.from_pretrained(tmp_dirname)
            assert isinstance(model, PeftModelForTokenClassification)

        # check if kwargs are passed correctly
        model = AutoPeftModelForTokenClassification.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, PeftModelForTokenClassification)
        assert model.base_model.classifier.original_module.weight.dtype == self.dtype

        adapter_name = "default"
        is_trainable = False
        # check that positional args are passed through correctly
        _ = AutoPeftModelForTokenClassification.from_pretrained(
            model_id, adapter_name, is_trainable, torch_dtype=self.dtype
        )

    def test_peft_question_answering(self):
        model_id = "peft-internal-testing/tiny_OPTForQuestionAnswering-lora"
        model = AutoPeftModelForQuestionAnswering.from_pretrained(model_id)
        assert isinstance(model, PeftModelForQuestionAnswering)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            model = AutoPeftModelForQuestionAnswering.from_pretrained(tmp_dirname)
            assert isinstance(model, PeftModelForQuestionAnswering)

        # check if kwargs are passed correctly
        model = AutoPeftModelForQuestionAnswering.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, PeftModelForQuestionAnswering)
        assert model.base_model.qa_outputs.original_module.weight.dtype == self.dtype

        adapter_name = "default"
        is_trainable = False
        # check that positional args are passed through correctly
        _ = AutoPeftModelForQuestionAnswering.from_pretrained(
            model_id, adapter_name, is_trainable, torch_dtype=self.dtype
        )

    def test_peft_feature_extraction(self):
        model_id = "peft-internal-testing/tiny_OPTForFeatureExtraction-lora"
        model = AutoPeftModelForFeatureExtraction.from_pretrained(model_id)
        assert isinstance(model, PeftModelForFeatureExtraction)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            model = AutoPeftModelForFeatureExtraction.from_pretrained(tmp_dirname)
            assert isinstance(model, PeftModelForFeatureExtraction)

        # check if kwargs are passed correctly
        model = AutoPeftModelForFeatureExtraction.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, PeftModelForFeatureExtraction)
        assert model.base_model.model.decoder.embed_tokens.weight.dtype == self.dtype

        adapter_name = "default"
        is_trainable = False
        # check that positional args are passed through correctly
        _ = AutoPeftModelForFeatureExtraction.from_pretrained(
            model_id, adapter_name, is_trainable, torch_dtype=self.dtype
        )

    def test_peft_whisper(self):
        model_id = "peft-internal-testing/tiny_WhisperForConditionalGeneration-lora"
        model = AutoPeftModel.from_pretrained(model_id)
        assert isinstance(model, PeftModel)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            model = AutoPeftModel.from_pretrained(tmp_dirname)
            assert isinstance(model, PeftModel)

        # check if kwargs are passed correctly
        model = AutoPeftModel.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, PeftModel)
        assert model.base_model.model.model.encoder.embed_positions.weight.dtype == self.dtype

        adapter_name = "default"
        is_trainable = False
        # check that positional args are passed through correctly
        _ = AutoPeftModel.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)