intel-extension-for-pytorch
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    """Configuration holding the architecture hyperparameters of ChatGLM-family models."""

    model_type = "chatglm"

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        classifier_dropout=None,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        # vocab_size mirrors padded_vocab_size so code expecting the
        # standard attribute name keeps working.
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.classifier_dropout = classifier_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        # Multi-query attention: K/V heads are shared across groups of
        # query heads; multi_query_group_num is the number of K/V groups.
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        # Built-in weight quantization; 0 disables it.
        self.quantization_bit = quantization_bit
        # Prefix-tuning (P-Tuning v2) settings.
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        super().__init__(**kwargs)
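
Since ChatGLMConfig subclasses PretrainedConfig, it inherits JSON serialization for free. The sketch below is a hypothetical usage example, not part of the file: it instantiates the config with an overridden seq_length, checks that the default per-head dimension (hidden_size / num_attention_heads = 4096 / 32 = 128) matches kv_channels, and round-trips the object through save_pretrained/from_pretrained. Only transformers APIs shown above are used; the temporary directory is illustrative.

import tempfile

config = ChatGLMConfig(seq_length=8192)

# For the default sizes, the per-head dimension equals kv_channels:
# 4096 hidden units / 32 attention heads = 128.
assert config.hidden_size // config.num_attention_heads == config.kv_channels

# PretrainedConfig provides JSON round-tripping out of the box.
with tempfile.TemporaryDirectory() as tmp:
    config.save_pretrained(tmp)  # writes <tmp>/config.json
    reloaded = ChatGLMConfig.from_pretrained(tmp)

assert reloaded.seq_length == 8192
assert reloaded.model_type == "chatglm"

Keyword arguments not consumed by __init__ fall through **kwargs to PretrainedConfig, which is how shared options such as torch_dtype or tie_word_embeddings end up in the same config.json.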