peft

Форк
0
/
test_auto.py 
209 строк · 8.8 Кб
1
# Copyright 2023-present the HuggingFace Inc. team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import tempfile
15
import unittest
16

17
import torch
18

19
from peft import (
20
    AutoPeftModel,
21
    AutoPeftModelForCausalLM,
22
    AutoPeftModelForFeatureExtraction,
23
    AutoPeftModelForQuestionAnswering,
24
    AutoPeftModelForSeq2SeqLM,
25
    AutoPeftModelForSequenceClassification,
26
    AutoPeftModelForTokenClassification,
27
    PeftModel,
28
    PeftModelForCausalLM,
29
    PeftModelForFeatureExtraction,
30
    PeftModelForQuestionAnswering,
31
    PeftModelForSeq2SeqLM,
32
    PeftModelForSequenceClassification,
33
    PeftModelForTokenClassification,
34
)
35
from peft.utils import infer_device
36

37

38
class PeftAutoModelTester(unittest.TestCase):
    """Check that each ``AutoPeftModel*`` class loads into the matching ``PeftModel*`` class.

    Every test loads a tiny adapter checkpoint and verifies that:
    * the returned object has the expected PEFT model type,
    * (optionally) the model survives a ``save_pretrained`` / ``from_pretrained`` round trip,
    * ``torch_dtype`` is forwarded to the base model,
    * ``adapter_name`` and ``is_trainable`` are accepted as positional arguments.
    """

    # dtype used to check that ``torch_dtype`` is forwarded; float16 on MPS, bfloat16 elsewhere.
    dtype = torch.float16 if infer_device() == "mps" else torch.bfloat16

    def _check_auto_model(self, auto_cls, peft_cls, model_id, get_weight, check_save_load=True):
        """Run the shared loading checks for one ``AutoPeftModel*`` class.

        Args:
            auto_cls: The ``AutoPeftModel*`` class under test.
            peft_cls: The class that ``auto_cls.from_pretrained`` is expected to return an instance of.
            model_id: Hub id (or local path) of the adapter checkpoint to load.
            get_weight: Callable that takes the loaded model and returns a weight tensor whose dtype
                should match the ``torch_dtype`` passed to ``from_pretrained``.
            check_save_load: If True, also round-trip the model through a temporary directory.
        """
        model = auto_cls.from_pretrained(model_id)
        assert isinstance(model, peft_cls)

        if check_save_load:
            with tempfile.TemporaryDirectory() as tmp_dirname:
                model.save_pretrained(tmp_dirname)

                # Reload from the local directory while it still exists.
                model = auto_cls.from_pretrained(tmp_dirname)
                assert isinstance(model, peft_cls)

        # check if kwargs are passed correctly
        model = auto_cls.from_pretrained(model_id, torch_dtype=self.dtype)
        assert isinstance(model, peft_cls)
        assert get_weight(model).dtype == self.dtype

        # adapter_name and is_trainable must be accepted positionally, alongside extra kwargs.
        adapter_name = "default"
        is_trainable = False
        model = auto_cls.from_pretrained(model_id, adapter_name, is_trainable, torch_dtype=self.dtype)
        assert isinstance(model, peft_cls)

    def test_peft_causal_lm(self):
        self._check_auto_model(
            AutoPeftModelForCausalLM,
            PeftModelForCausalLM,
            "peft-internal-testing/tiny-OPTForCausalLM-lora",
            get_weight=lambda model: model.base_model.lm_head.weight,
        )

    def test_peft_causal_lm_extended_vocab(self):
        # This checkpoint is intentionally not round-tripped through save_pretrained here
        # (same coverage as before; NOTE(review): presumably because reloading needs the
        # extended tokenizer alongside the adapter -- confirm before enabling).
        self._check_auto_model(
            AutoPeftModelForCausalLM,
            PeftModelForCausalLM,
            "peft-internal-testing/tiny-random-OPTForCausalLM-extended-vocab",
            get_weight=lambda model: model.base_model.lm_head.weight,
            check_save_load=False,
        )

    def test_peft_seq2seq_lm(self):
        self._check_auto_model(
            AutoPeftModelForSeq2SeqLM,
            PeftModelForSeq2SeqLM,
            "peft-internal-testing/tiny_T5ForSeq2SeqLM-lora",
            get_weight=lambda model: model.base_model.lm_head.weight,
        )

    def test_peft_sequence_cls(self):
        self._check_auto_model(
            AutoPeftModelForSequenceClassification,
            PeftModelForSequenceClassification,
            "peft-internal-testing/tiny_OPTForSequenceClassification-lora",
            get_weight=lambda model: model.score.original_module.weight,
        )

    def test_peft_token_classification(self):
        self._check_auto_model(
            AutoPeftModelForTokenClassification,
            PeftModelForTokenClassification,
            "peft-internal-testing/tiny_GPT2ForTokenClassification-lora",
            get_weight=lambda model: model.base_model.classifier.original_module.weight,
        )

    def test_peft_question_answering(self):
        self._check_auto_model(
            AutoPeftModelForQuestionAnswering,
            PeftModelForQuestionAnswering,
            "peft-internal-testing/tiny_OPTForQuestionAnswering-lora",
            get_weight=lambda model: model.base_model.qa_outputs.original_module.weight,
        )

    def test_peft_feature_extraction(self):
        self._check_auto_model(
            AutoPeftModelForFeatureExtraction,
            PeftModelForFeatureExtraction,
            "peft-internal-testing/tiny_OPTForFeatureExtraction-lora",
            get_weight=lambda model: model.base_model.model.decoder.embed_tokens.weight,
        )

    def test_peft_whisper(self):
        # The generic AutoPeftModel is used here, so only the PeftModel base type is checked.
        self._check_auto_model(
            AutoPeftModel,
            PeftModel,
            "peft-internal-testing/tiny_WhisperForConditionalGeneration-lora",
            get_weight=lambda model: model.base_model.model.model.encoder.embed_positions.weight,
        )
210

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.