# Tests for the RWKV model (transformers test suite).
1# coding=utf-8
2# Copyright 2023 The HuggingFace Team. All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16
17import unittest
18from unittest.util import safe_repr
19
20from transformers import AutoTokenizer, RwkvConfig, is_torch_available
21from transformers.testing_utils import require_torch, slow, torch_device
22
23from ...generation.test_utils import GenerationTesterMixin
24from ...test_configuration_common import ConfigTester
25from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
26from ...test_pipeline_mixin import PipelineTesterMixin
27
28
29if is_torch_available():
30import torch
31
32from transformers import (
33RWKV_PRETRAINED_MODEL_ARCHIVE_LIST,
34RwkvForCausalLM,
35RwkvModel,
36)
37from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0
38else:
39is_torch_greater_or_equal_than_2_0 = False
40
41
class RwkvModelTester:
    """Fixture helper: builds tiny `RwkvConfig`s plus random synthetic inputs and
    runs the shape / consistency checks invoked by `RwkvModelTest`.

    NOTE(review): several constructor arguments (`hidden_dropout_prob`,
    `max_position_embeddings`, ...) mirror GPT-2-style testers; some of the
    kwargs forwarded in `get_config` may simply be ignored by `RwkvConfig` —
    confirm against the config class.
    """

    def __init__(
        self,
        parent,
        batch_size=14,
        seq_length=7,
        is_training=True,
        use_token_type_ids=False,
        use_input_mask=True,
        use_labels=True,
        use_mc_token_ids=True,
        vocab_size=99,
        hidden_size=32,
        num_hidden_layers=2,
        intermediate_size=37,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=16,
        type_sequence_label_size=2,
        num_labels=3,
        num_choices=4,
        scope=None,
    ):
        # `parent` is the unittest.TestCase instance; assertions are delegated to it.
        self.parent = parent
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.is_training = is_training
        self.use_token_type_ids = use_token_type_ids
        self.use_input_mask = use_input_mask
        self.use_labels = use_labels
        self.use_mc_token_ids = use_mc_token_ids
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.type_sequence_label_size = type_sequence_label_size
        self.num_labels = num_labels
        self.num_choices = num_choices
        self.scope = scope
        # BOS, EOS and PAD all share the last vocabulary index.
        self.bos_token_id = vocab_size - 1
        self.eos_token_id = vocab_size - 1
        self.pad_token_id = vocab_size - 1

    def get_large_model_config(self):
        """Return the config of a real pretrained RWKV checkpoint (needs network access)."""
        return RwkvConfig.from_pretrained("sgugger/rwkv-4-pile-7b")

    def prepare_config_and_inputs(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        """Build a small config plus random input tensors.

        Returns a 9-tuple:
        (config, input_ids, input_mask, head_mask (always None), token_type_ids,
        mc_token_ids, sequence_labels, token_labels, choice_labels) —
        optional entries are None when the corresponding `use_*` flag is False.
        """
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        mc_token_ids = None
        if self.use_mc_token_ids:
            mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = self.get_config(
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )

        return (
            config,
            input_ids,
            input_mask,
            None,  # head_mask slot — RWKV does not support head masking
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        )

    def get_config(
        self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False
    ):
        """Instantiate a tiny `RwkvConfig` from the tester's hyperparameters."""
        return RwkvConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            intermediate_size=self.intermediate_size,
            activation_function=self.hidden_act,
            resid_pdrop=self.hidden_dropout_prob,
            attn_pdrop=self.attention_probs_dropout_prob,
            n_positions=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            use_cache=True,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            gradient_checkpointing=gradient_checkpointing,
            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
            reorder_and_upcast_attn=reorder_and_upcast_attn,
        )

    def get_pipeline_config(self):
        """Same tiny config, but with a vocabulary large enough for pipeline tests."""
        config = self.get_config()
        config.vocab_size = 300
        return config

    def prepare_config_and_inputs_for_decoder(self):
        """Like `prepare_config_and_inputs`, but additionally returns random encoder
        hidden states and an encoder attention mask (and drops `mc_token_ids`)."""
        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = self.prepare_config_and_inputs()

        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

        return (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            sequence_labels,
            token_labels,
            choice_labels,
            encoder_hidden_states,
            encoder_attention_mask,
        )

    def create_and_check_rwkv_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        """Check output shapes of the bare `RwkvModel` (with hidden states enabled)."""
        config.output_hidden_states = True
        model = RwkvModel(config=config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        # One hidden state per layer, plus the embedding output.
        self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1)

    def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        """Check loss (scalar) and logits shapes of `RwkvForCausalLM`.

        NOTE(review): the name has a typo ("causl"); it is kept as-is because
        `RwkvModelTest.test_rwkv_lm_head_model` calls it by this exact name.
        """
        model = RwkvForCausalLM(config)
        model.to(torch_device)
        model.eval()

        result = model(input_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_state_equivalency(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
        """Verify that running the sequence in two chunks, carrying the recurrent
        state across the split, matches a single full-sequence forward pass."""
        model = RwkvModel(config=config)
        model.to(torch_device)
        model.eval()

        outputs = model(input_ids)
        output_whole = outputs.last_hidden_state

        outputs = model(input_ids[:, :2])
        output_one = outputs.last_hidden_state

        # Using the state computed on the first inputs, we will get the same output
        outputs = model(input_ids[:, 2:], state=outputs.state)
        output_two = outputs.last_hidden_state

        self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5))

    def create_and_check_forward_and_backwards(
        self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False
    ):
        """Check shapes and that the loss backpropagates (optionally with gradient
        checkpointing enabled)."""
        model = RwkvForCausalLM(config)
        model.to(torch_device)
        if gradient_checkpointing:
            model.gradient_checkpointing_enable()

        result = model(input_ids, labels=input_ids)
        self.parent.assertEqual(result.loss.shape, ())
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
        result.loss.backward()

    def prepare_config_and_inputs_for_common(self):
        """Entry point used by the common ModelTesterMixin machinery: only
        `input_ids` is fed to the model."""
        config_and_inputs = self.prepare_config_and_inputs()

        (
            config,
            input_ids,
            input_mask,
            head_mask,
            token_type_ids,
            mc_token_ids,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs

        inputs_dict = {"input_ids": input_ids}

        return config, inputs_dict
262
263
@unittest.skipIf(
    not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
)
@require_torch
class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common-suite tests for `RwkvModel` / `RwkvForCausalLM` on tiny random configs."""

    all_model_classes = (RwkvModel, RwkvForCausalLM) if is_torch_available() else ()
    pipeline_model_mapping = (
        {"feature-extraction": RwkvModel, "text-generation": RwkvForCausalLM} if is_torch_available() else {}
    )
    # all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
    fx_compatible = False
    test_missing_keys = False
    test_model_parallel = False
    test_pruning = False
    test_head_masking = False  # Rwkv does not support head masking

    def setUp(self):
        self.model_tester = RwkvModelTester(self)
        self.config_tester = ConfigTester(
            self, config_class=RwkvConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
        )

    def assertInterval(self, member, container, msg=None):
        r"""
        Simple utility function to check if a member is inside an interval.

        Args:
            member: a `torch.Tensor`, list, or tuple of numbers whose min/max must
                both fall inside the interval.
            container: a 2-element list or tuple `[expected_min, expected_max]`.
            msg: optional message shown on failure.

        Raises:
            TypeError: if `member` or `container` has an unsupported type.
            ValueError: if `container` does not have exactly 2 elements.
        """
        if isinstance(member, torch.Tensor):
            max_value, min_value = member.max().item(), member.min().item()
        elif isinstance(member, (list, tuple)):
            max_value, min_value = max(member), min(member)
        else:
            # Bug fix: previously an unsupported `member` type fell through with
            # `max_value`/`min_value` unbound, raising a confusing UnboundLocalError.
            raise TypeError("member should be a torch.Tensor, list or tuple")

        # Bug fix: the error message always promised "list or tuple" but only
        # lists were accepted; tuples are now accepted too.
        if not isinstance(container, (list, tuple)):
            raise TypeError("container should be a list or tuple")
        elif len(container) != 2:
            raise ValueError("container should have 2 elements")

        expected_min, expected_max = container

        is_inside_interval = (min_value >= expected_min) and (max_value <= expected_max)

        if not is_inside_interval:
            standardMsg = "%s not found in %s" % (safe_repr(member), safe_repr(container))
            self.fail(self._formatMessage(msg, standardMsg))

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_rwkv_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_rwkv_model(*config_and_inputs)

    def test_rwkv_lm_head_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_causl_lm(*config_and_inputs)

    def test_state_equivalency(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_state_equivalency(*config_and_inputs)

    def test_initialization(self):
        """Check RWKV-specific parameter initialization ranges for the time-mixing
        parameters of freshly constructed models."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config=config)
            for name, param in model.named_parameters():
                if "time_decay" in name:
                    if param.requires_grad:
                        # time_decay is expected to span exactly [-5.0, 3.0].
                        self.assertTrue(param.data.max().item() == 3.0)
                        self.assertTrue(param.data.min().item() == -5.0)
                elif "time_first" in name:
                    if param.requires_grad:
                        # check if it's a ones like
                        self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
                elif any(x in name for x in ["time_mix_key", "time_mix_receptance"]):
                    if param.requires_grad:
                        self.assertInterval(
                            param.data,
                            [0.0, 1.0],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )
                elif "time_mix_value" in name:
                    if param.requires_grad:
                        self.assertInterval(
                            param.data,
                            [0.0, 1.3],
                            msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                        )

    def test_attention_outputs(self):
        r"""
        Overriding the test_attention_outputs test as the attention outputs of Rwkv are different from other models
        it has a shape `batch_size, seq_len, hidden_size`.
        """
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        config.return_dict = True

        seq_len = getattr(self.model_tester, "seq_length", None)

        for model_class in self.all_model_classes:
            # First pass: request attentions through the inputs dict.
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = False
            config.return_dict = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # check that output_attentions also work using config
            del inputs_dict["output_attentions"]
            config.output_attentions = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)
            attentions = outputs.attentions
            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)

            # RWKV "attention" outputs are per-token states, not square maps.
            self.assertListEqual(
                list(attentions[0].shape[-3:]),
                [batch_size, seq_len, config.hidden_size],
            )
            out_len = len(outputs)

            # Check attention is always last and order is fine
            inputs_dict["output_attentions"] = True
            inputs_dict["output_hidden_states"] = True
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = self._prepare_for_class(inputs_dict, model_class)
            batch_size = inputs["input_ids"].shape[0]
            with torch.no_grad():
                outputs = model(**inputs)

            # Enabling hidden states must add exactly one extra output entry.
            added_hidden_states = 1
            self.assertEqual(out_len + added_hidden_states, len(outputs))

            self_attentions = outputs.attentions

            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
            self.assertListEqual(
                list(self_attentions[0].shape[-3:]),
                [batch_size, seq_len, config.hidden_size],
            )

    @slow
    def test_model_from_pretrained(self):
        for model_name in RWKV_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = RwkvModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
425
426
@unittest.skipIf(
    not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
)
@slow
class RWKVIntegrationTests(unittest.TestCase):
    """Slow end-to-end generation checks against the real 169M RWKV checkpoint."""

    def setUp(self):
        self.model_id = "RWKV/rwkv-4-169m-pile"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)

    def test_simple_generate(self):
        # Greedy decoding in full precision must reproduce this exact continuation.
        expected_output = "Hello my name is Jasmine and I am a newbie to the"
        model = RwkvForCausalLM.from_pretrained(self.model_id).to(torch_device)

        prompt_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
        generated_ids = model.generate(prompt_ids, max_new_tokens=10)
        decoded_text = self.tokenizer.decode(generated_ids[0].tolist())

        self.assertEqual(decoded_text, expected_output)

    def test_simple_generate_bf16(self):
        # The bfloat16 weights are expected to produce the same greedy continuation.
        expected_output = "Hello my name is Jasmine and I am a newbie to the"

        prompt_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device)
        model = RwkvForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device)

        generated_ids = model.generate(prompt_ids, max_new_tokens=10)
        decoded_text = self.tokenizer.decode(generated_ids[0].tolist())

        self.assertEqual(decoded_text, expected_output)
456