transformers / test_modeling_rwkv.py · 455 lines · 17.3 KB
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from unittest.util import safe_repr

from transformers import AutoTokenizer, RwkvConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
    import torch

    from transformers import (
        RWKV_PRETRAINED_MODEL_ARCHIVE_LIST,
        RwkvForCausalLM,
        RwkvModel,
    )
    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
else:
    is_torch_greater_or_equal_than_2_0 = False


class RwkvModelTester:
    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_token_type_ids=False,
        use_input_mask=True,
        use_labels=True,
        use_mc_token_ids=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_token_type_ids = use_token_type_ids
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.use_mc_token_ids = use_mc_token_ids
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1
        self.pad_token_id = vocab_size - 1

    def get_large_model_config(self):
        return RwkvConfig.from_pretrained("sgugger/rwkv-4-pile-7b")

    def prepare_config_and_inputs(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        mc_token_ids = None
        if self.use_mc_token_ids:
            mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config(
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )

        return (
            config,
            input_ids,
            input_mask,
            None,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        return RwkvConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=self.intermediate_size,
            activation_function=self.hidden_act,
            resid_pdrop=self.hidden_dropout_prob,
            attn_pdrop=self.attention_probs_dropout_prob,
            n_positions=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )

    def get_pipeline_config(self):
        config = self.get_config()
        config.vocab_size = 300
        return config

    def prepare_config_and_inputs_for_decoder(self):
        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_rwkv_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        config.output_hidden_states = True
        model = RwkvModel(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1)

    def create_and_check_causal_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = RwkvForCausalLM(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_state_equivalency(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        model = RwkvModel(config=config)
        model.to(torch_device)
        model.eval()

        outputs = model(input_ids)
        output_whole = outputs.last_hidden_state

        outputs = model(input_ids[:, :2])
        output_one = outputs.last_hidden_state

        # Using the state computed on the first part of the input, we should get the same output for the rest
        outputs = model(input_ids[:, 2:], state=outputs.state)
        output_two = outputs.last_hidden_state

        self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))

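    # Note: the check above relies on RWKV being a recurrent architecture. All of the sequence
    # history is carried in the `state` returned by the forward pass, so running a prefix and then
    # the remainder (with that state passed back in) should reproduce the full-sequence hidden
    # states up to numerical tolerance.
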
    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
    ):
        model = RwkvForCausalLM(config)
        model.to(torch_device)
        if gradient_checkpointing:
            model.gradient_checkpointing_enable()

        result = model(input_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        result.loss.backward()

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()

        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs

        inputs_dict = {"input_ids": input_ids}

        return config, inputs_dict


@unittest.skipIf(
    not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
)
@require_torch
class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": RwkvModel, "text-generation": RwkvForCausalLM} if is_torch_available() else {}
    )
    # all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
    fx_compatible = False
    test_missing_keys = False
    test_model_parallel = False
    test_pruning = False
    test_head_masking = False  # Rwkv does not support head masking

    def setUp(self):
        self.model_tester = RwkvModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=RwkvConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
        )

    def assertInterval(self, member, container, msg=None):
        r"""
        Simple utility function to check if a member is inside an interval.
        """
        if isinstance(member, torch.Tensor):
            max_value, min_value = member.max().item(), member.min().item()
        elif isinstance(member, (list, tuple)):
            max_value, min_value = max(member), min(member)
        else:
            raise TypeError("member should be a tensor, a list or a tuple")

        if not isinstance(container, (list, tuple)):
            raise TypeError("container should be a list or tuple")
        elif len(container) != 2:
            raise ValueError("container should have 2 elements")

        expected_min, expected_max = container

        is_inside_interval = (min_value >= expected_min) and (max_value <= expected_max)

        if not is_inside_interval:
            standardMsg = "%s not found in %s" % (safe_repr(member), safe_repr(container))
            self.fail(self._formatMessage(msg, standardMsg))

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_rwkv_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_rwkv_model(*config_and_inputs)

    def test_rwkv_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_causal_lm(*config_and_inputs)

    def test_state_equivalency(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_state_equivalency(*config_and_inputs)

    def test_initialization(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            for name, param in model.named_parameters():
                if "time_decay" in name:
                    if param.requires_grad:
                        self.assertTrue(param.data.max().item() == 3.0)
                        self.assertTrue(param.data.min().item() == -5.0)
                elif "time_first" in name:
                    if param.requires_grad:
                        # check that the parameter is initialized to all ones
                        self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
                elif any(x in name for x in ["time_mix_key", "time_mix_receptance"]):
                    if param.requires_grad:
                        self.assertInterval(
                            param.data,
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} does not seem properly initialized",
                        )
                elif "time_mix_value" in name:
                    if param.requires_grad:
                        self.assertInterval(
                            param.data,
                            [0.0, 1.3],
                            msg=f"Parameter {name} of model {model_class} does not seem properly initialized",
                        )

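    # For reference, the checks above assert the following per parameter name:
    #   time_decay                          -> values spanning exactly [-5.0, 3.0]
    #   time_first                          -> all ones
    #   time_mix_key / time_mix_receptance  -> values within [0.0, 1.0]
    #   time_mix_value                      -> values within [0.0, 1.3]
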
    def test_attention_outputs(self):
        r"""
        Overriding test_attention_outputs because the attention outputs of Rwkv differ from those of other
        models: each one has shape `(batch_size, seq_len, hidden_size)`.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)

        for model_class in self.all_model_classes:
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also works when it is set through the config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [batch_size, seq_len, config.hidden_size],
            )
            out_len = len(outputs)

            # Check that the attentions are always last in the outputs and that the ordering is preserved
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)

            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [batch_size, seq_len, config.hidden_size],
            )

    @slow
    def test_model_from_pretrained(self):
        for model_name in RWKV_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = RwkvModel.from_pretrained(model_name)
            self.assertIsNotNone(model)


@unittest.skipIf(
    not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
)
@slow
class RWKVIntegrationTests(unittest.TestCase):
    def setUp(self):
        self.model_id = "RWKV/rwkv-4-169m-pile"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    def test_simple_generate(self):
        expected_output = "Hello my name is Jasmine and I am a newbie to the"
        model = RwkvForCausalLM.from_pretrained(self.model_id).to(torch_device)

        input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
        output = model.generate(input_ids, max_new_tokens=10)
        output_sentence = self.tokenizer.decode(output[0].tolist())

        self.assertEqual(output_sentence, expected_output)

    def test_simple_generate_bf16(self):
        expected_output = "Hello my name is Jasmine and I am a newbie to the"

        input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
        model = RwkvForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)

        output = model.generate(input_ids, max_new_tokens=10)
        output_sentence = self.tokenizer.decode(output[0].tolist())

        self.assertEqual(output_sentence, expected_output)
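

# A minimal, hedged usage sketch (not part of the test suite) illustrating the state-reuse API
# exercised by `create_and_check_state_equivalency` above. It assumes the "RWKV/rwkv-4-169m-pile"
# checkpoint used in the integration tests and that its config keeps the default `use_cache=True`,
# so the forward pass returns a reusable `state`. Because this module uses relative imports, run it
# with `python -m <path.to.this.test.module>` from the repository root rather than as a plain script.
if __name__ == "__main__" and is_torch_available() and is_torch_greater_or_equal_than_2_0:
    sketch_tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
    sketch_model = RwkvModel.from_pretrained("RWKV/rwkv-4-169m-pile").eval()

    sketch_input_ids = sketch_tokenizer("Hello my name is", return_tensors="pt").input_ids

    with torch.no_grad():
        # Full-sequence pass.
        full = sketch_model(sketch_input_ids).last_hidden_state
        # Chunked pass: run the first two tokens, then continue from the returned state.
        first = sketch_model(sketch_input_ids[:, :2])
        rest = sketch_model(sketch_input_ids[:, 2:], state=first.state)
        chunked = torch.cat([first.last_hidden_state, rest.last_hidden_state], dim=1)

    # With a recurrent state, the chunked pass should match the full pass up to numerical tolerance.
    print("state reuse matches full pass:", torch.allclose(full, chunked, atol=1e-5))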