# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
import tempfile
from unittest import TestCase

import pytest
import torch
from parameterized import parameterized
from torch.testing import assert_close

from peft.mapping import get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.multitask_prompt_tuning import MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from peft.utils.other import WEIGHTS_NAME, prepare_model_for_int8_training
from peft.utils.save_and_load import get_peft_model_state_dict
from tests.testing_common import PeftCommonTester


def is_llama_available() -> bool:
    """Check if Llama is available in the transformers library (it's not in earlier versions)."""
    try:
        return importlib.util.find_spec("transformers.models.llama.modeling_llama") is not None
    except ModuleNotFoundError:
        return False


if is_llama_available():
    # We guard the import statement so that our unit tests will pass in CI environments
    # that don't have a transformers package with Llama.
    from transformers import LlamaConfig, LlamaForCausalLM


class MultiTaskPromptTuningTester(TestCase, PeftCommonTester):
    """
    Tests for the MultiTaskPromptTuning model.

    Some of these tests were adapted from `test_peft_model.py` (which has been refactored since), but since we haven't
    checked in the test checkpoints for Llama into `hf-internal-testing`, we separate them for now.
    """

    def setUp(self):
        """Check that Llama is available in the transformers package before running each test."""
        if not is_llama_available():
            self.skipTest("Llama not available in transformers. Skipping test.")

    @staticmethod
    def _create_test_llama_config():
        """Create a test config for a small Llama model for testing."""
        return LlamaConfig(
            vocab_size=16,
            hidden_size=8,
            intermediate_size=8,
            num_hidden_layers=8,
            num_attention_heads=4,
            use_cache=False,
        )

    @classmethod
    def _create_multitask_prompt_tuning_config(cls) -> MultitaskPromptTuningConfig:
        return MultitaskPromptTuningConfig(
            task_type="CAUSAL_LM",
            num_virtual_tokens=50,
            num_tasks=3,
            prompt_tuning_init_text=(
                "classify the following into either positive or negative, or entailment, neutral or contradiction:"
            ),
        )
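
    # Multitask prompt tuning learns a shared prompt embedding plus low-rank task-specific
    # row/column factors for each of the `num_tasks` tasks; that is why the adapter state dicts
    # checked in the save/load tests below are expected to contain three tensors rather than one.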

    def test_prepare_for_training(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        dummy_output = model.get_input_embeddings()(dummy_input)

        assert not dummy_output.requires_grad

    def test_prepare_for_int8_training(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = prepare_model_for_int8_training(model)
        model = model.to(self.torch_device)

        for param in model.parameters():
            assert not param.requires_grad

        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())

        # For backward compatibility: prefer `enable_input_require_grads` when available; otherwise
        # register a forward hook so the embedding outputs require grad and gradients can reach the
        # prompt parameters even though the base weights are frozen.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        dummy_input = torch.LongTensor([[1, 1, 1]]).to(self.torch_device)
        dummy_output = model.get_input_embeddings()(dummy_input)

        assert dummy_output.requires_grad

    def test_save_pretrained(self) -> None:
        seed = 420
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)

            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            assert state_dict.keys() == state_dict_from_pretrained.keys()

            # check that the number of saved parameters is 3 -- the shared prompt embeddings plus
            # the two task-specific prefix factors
            assert len(state_dict) == 3

            # check if tensors equal
            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # check if `adapter_model.safetensors` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.safetensors"))

            # check if `adapter_config.json` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))

            # check if `pytorch_model.bin` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin"))

            # check if `config.json` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))

    def test_save_pretrained_regression(self) -> None:
        seed = 420
        torch.manual_seed(seed)
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        with tempfile.TemporaryDirectory() as tmp_dirname:
            model.save_pretrained(tmp_dirname, safe_serialization=False)

            torch.manual_seed(seed)
            model_from_pretrained = LlamaForCausalLM(self._create_test_llama_config())
            model_from_pretrained = PeftModel.from_pretrained(model_from_pretrained, tmp_dirname)

            # check if the state dicts are equal
            state_dict = get_peft_model_state_dict(model)

            state_dict_from_pretrained = get_peft_model_state_dict(model_from_pretrained)

            # check if same keys
            assert state_dict.keys() == state_dict_from_pretrained.keys()

            # check that the number of saved parameters is 3 -- the shared prompt embeddings plus
            # the two task-specific prefix factors
            assert len(state_dict) == 3

            # check if tensors equal
            for key in state_dict.keys():
                assert torch.allclose(
                    state_dict[key].to(self.torch_device), state_dict_from_pretrained[key].to(self.torch_device)
                )

            # check if `adapter_model.bin` is present for regression
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_model.bin"))

            # check if `adapter_config.json` is present
            assert os.path.exists(os.path.join(tmp_dirname, "adapter_config.json"))

            # check if `pytorch_model.bin` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "pytorch_model.bin"))

            # check if `config.json` is not present
            assert not os.path.exists(os.path.join(tmp_dirname, "config.json"))

    def test_generate(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())
        model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
        model = model.to(self.torch_device)

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
        task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

        # check if `generate` works
        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

        # check if `generate` works if positional arguments are passed
        _ = model.generate(input_ids, attention_mask=attention_mask, task_ids=task_ids)

    def test_use_cache(self) -> None:
        """Test that MultiTaskPromptTuning works when the Llama config sets use_cache=True."""
        torch.manual_seed(0)
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        task_ids = torch.LongTensor([1, 2]).to(self.torch_device)

        original = LlamaForCausalLM(self._create_test_llama_config()).eval()
        mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config())
        mpt = mpt.to(self.torch_device)

        expected = mpt.generate(input_ids=input_ids, max_length=8, task_ids=task_ids)

        # Set use_cache = True and generate output again; the result should match the uncached run.
        mpt.base_model.config.use_cache = True
        actual = mpt.generate(input_ids=input_ids, max_length=8, task_ids=task_ids)
        assert_close(expected, actual, rtol=0, atol=0)

    def test_bf16_inference(self) -> None:
        """Test that MultiTaskPromptTuning works with a half-precision (bfloat16) Llama model."""
        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        task_ids = torch.tensor([1, 2]).to(self.torch_device)

        original = LlamaForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-random-LlamaForCausalLM", torch_dtype=torch.bfloat16
        )
        mpt = get_peft_model(original, self._create_multitask_prompt_tuning_config())
        mpt = mpt.to(self.torch_device)
        _ = mpt.generate(input_ids=input_ids, task_ids=task_ids)

    def test_generate_text_with_random_init(self) -> None:
        model = LlamaForCausalLM(self._create_test_llama_config())

        config = self._create_multitask_prompt_tuning_config()
        config.prompt_tuning_init = MultitaskPromptTuningInit.RANDOM

        model = get_peft_model(model, config)
        model = model.to(self.torch_device)

        input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
        attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
        task_ids = torch.LongTensor([0]).to(self.torch_device)

        # check if `generate` works
        _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

        with pytest.raises(ValueError):
            # check if `generate` raises an error if task_ids are not passed
            _ = model.generate(input_ids, attention_mask=attention_mask)

    @parameterized.expand(
        [
            MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
            MultitaskPromptTuningInit.EXACT_SOURCE_TASK,
            MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
        ],
    )
    def test_generate_text_with_other_init(self, prompt_tuning_init) -> None:
        with tempfile.TemporaryDirectory() as tmp_dirname:
            model = LlamaForCausalLM(self._create_test_llama_config())
            model = get_peft_model(model, self._create_multitask_prompt_tuning_config())
            model.save_pretrained(tmp_dirname, safe_serialization=False)  # because torch.load is used

            config = MultitaskPromptTuningConfig(
                task_type="CAUSAL_LM",
                num_virtual_tokens=50,
                num_tasks=1,
                prompt_tuning_init_text=(
                    "classify the following into either positive or negative, or entailment, neutral or contradiction:"
                ),
                prompt_tuning_init=prompt_tuning_init,
                prompt_tuning_init_state_dict_path=os.path.join(tmp_dirname, WEIGHTS_NAME),
            )
            model = LlamaForCausalLM(self._create_test_llama_config())
            model = get_peft_model(model, config)
            model = model.to(self.torch_device)

            input_ids = torch.LongTensor([[1, 1, 1], [2, 1, 2]]).to(self.torch_device)
            attention_mask = torch.LongTensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
            task_ids = torch.LongTensor([0]).to(self.torch_device)

            # check if `generate` works
            _ = model.generate(input_ids=input_ids, attention_mask=attention_mask, task_ids=task_ids)

            with pytest.raises(ValueError):
                # check if `generate` raises an error if task_ids are not passed
                _ = model.generate(input_ids, attention_mask=attention_mask)
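
# A typical way to run just this module locally (assuming the repository layout implied by the
# `tests.testing_common` import above):
#     pytest tests/test_multitask_prompt_tuning.py -v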
