# encoding: UTF-8
import math
import copy
import warnings

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from copy import deepcopy

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import (
    LogitsProcessorList,
    StoppingCriteriaList,
    GenerationConfig,
    ModelOutput,
)

from .configuration_chatglm import ChatGLMConfig

# flags required to enable jit fusion kernels
# import sys
# if sys.platform != "darwin":
#     torch._C._jit_set_profiling_mode(False)
#     torch._C._jit_set_profiling_executor(False)
#     torch._C._jit_override_can_fuse_on_cpu(True)
#     torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM"
_CONFIG_FOR_DOC = "ChatGLMConfig"

CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "THUDM/chatglm3-6b",
    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
]


def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config: ChatGLMConfig):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            kv_size = (
                config.num_layers
                * config.kv_channels
                * config.multi_query_group_num
                * 2
            )
            self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, kv_size),
            )
        else:
            self.embedding = torch.nn.Embedding(
                config.pre_seq_len,
                config.num_layers
                * config.kv_channels
                * config.multi_query_group_num
                * 2,
            )

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values


def split_tensor_along_last_dim(
    tensor: torch.Tensor,
    num_partitions: int,
    contiguous_split_chunks: bool = False,
) -> List[torch.Tensor]:
    """Split a tensor along its last dimension.
    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into.
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.
    Returns:
        A list of Tensors.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list


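# Illustrative usage sketch (editorial addition, not part of the original
# file): splitting a fused QKV projection of width 3 * h into three [s, b, h]
# chunks, as SelfAttention.forward does below. Kept as a comment so importing
# this module stays side-effect free; the tensor names are hypothetical.
#
#   mixed_qkv = torch.randn(4, 2, 3 * 8)               # [s=4, b=2, 3*h]
#   q, k, v = split_tensor_along_last_dim(mixed_qkv, 3)
#   assert q.shape == k.shape == v.shape == (4, 2, 8)

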
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (
            10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)
        )
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
        self,
        seq_len: int,
        n_elem: int,
        dtype: torch.dtype,
        device: torch.device,
        base: int = 10000,
    ):
        """Enhanced Transformer with Rotary Position Embedding.
        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (
            base
            ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)
        )

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=torch.float, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len,
            self.dim,
            dtype=self.inv_freq.dtype,
            device=self.inv_freq.device,
        )


@torch.jit.script
def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


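# Shape sketch (editorial addition, not part of the original file).
# ChatGLMModel below constructs RotaryEmbedding(rotary_dim // 2), so the rope
# cache covers only the first half of each head dimension; the remaining half
# passes through unrotated (the `x_pass` branch above). With hypothetical
# sizes, commented out to keep import side-effect free:
#
#   rope = RotaryEmbedding(dim=32)           # e.g. head dim hn = 64 -> rotate half
#   cache = rope(128)                        # [128, 16, 2] (cos, sin) pairs
#   q = torch.randn(128, 1, 4, 64)           # [sq, b, np, hn]
#   q_rot = apply_rotary_pos_emb(q, cache)   # rotates q[..., :32], keeps q[..., 32:]

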
class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.empty(normalized_shape, device=device, dtype=dtype)
        )
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


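# In equation form (editorial restatement of the forward pass above):
#
#   RMSNorm(x) = w * x / sqrt(mean_i(x_i^2) + eps)
#
# i.e. LayerNorm without mean-centering and without a bias term, computed in
# fp32 and cast back to the input dtype.

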
class CoreAttention(torch.nn.Module):
    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = (
            projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        pytorch_major_version = int(torch.__version__.split(".")[0])
        if pytorch_major_version >= 2:
            query_layer, key_layer, value_layer = [
                k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]
            ]
            key_layer = key_layer.to(query_layer.dtype)
            value_layer = value_layer.to(query_layer.dtype)
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                context_layer = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer, is_causal=True
                )
            else:
                if attention_mask is not None:
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(
                    query_layer, key_layer, value_layer, attention_mask
                )
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (
                self.hidden_size_per_partition,
            )
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores

            # [b, np, sq, sk]
            output_size = (
                query_layer.size(1),
                query_layer.size(2),
                query_layer.size(0),
                key_layer.size(0),
            )

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(
                output_size[2], output_size[0] * output_size[1], -1
            )
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(
                output_size[3], output_size[0] * output_size[1], -1
            )

            # preallocating input tensor: [b * np, sq, sk]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1],
                output_size[2],
                output_size[3],
                dtype=query_layer.dtype,
                device=query_layer.device,
            )

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ===========================
            # Attention probs and dropout
            # ===========================

            # attention scores and attention mask [b, np, sq, sk]
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()
            if self.coeff is not None:
                attention_scores = attention_scores * self.coeff
            if (
                attention_mask is None
                and attention_scores.shape[2] == attention_scores.shape[3]
            ):
                attention_mask = torch.ones(
                    output_size[0],
                    1,
                    output_size[2],
                    output_size[3],
                    device=attention_scores.device,
                    dtype=torch.bool,
                )
                attention_mask.tril_()
                attention_mask = ~attention_mask
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(
                    attention_mask, float("-inf")
                )
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = attention_probs.type_as(value_layer)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.attention_dropout(attention_probs)

            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (
                value_layer.size(1),
                value_layer.size(2),
                query_layer.size(0),
                value_layer.size(3),
            )
            # change view [sk, b * np, hn]
            value_layer = value_layer.view(
                value_layer.size(0), output_size[0] * output_size[1], -1
            )
            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(
                output_size[0] * output_size[1], output_size[2], -1
            )
            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (
                self.hidden_size_per_partition,
            )
            context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


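# Note on the mask convention (editorial comment, not in the original file):
# throughout this file `attention_mask` is boolean with True marking positions
# a query may NOT attend to. scaled_dot_product_attention uses the opposite
# convention for boolean masks (True = attend), hence the
# `attention_mask = ~attention_mask` inversion in the PyTorch >= 2 branch
# above, while the fallback branch applies the mask directly via
# `masked_fill(attention_mask, -inf)`.

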
class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.
    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = (
            self.projection_size // config.num_attention_heads
        )
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                self.projection_size
                + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(
            config.hidden_size,
            self.qkv_hidden_size,
            bias=config.add_bias_linear or config.add_qkv_bias,
            device=device,
            **_config_to_kwargs(config),
        )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(
            self.projection_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            device=device,
            **_config_to_kwargs(config),
        )

    def _allocate_memory(
        self, inference_max_sequence_len, batch_size, device=None, dtype=None
    ):
        if self.multi_query_attention:
            num_attention_heads = self.num_multi_query_groups_per_partition
        else:
            num_attention_heads = self.num_attention_heads_per_partition
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition
                    * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (
                    self.num_multi_query_groups_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(
                mixed_x_layer, 3
            )

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1,
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
            )
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1,
                -1,
                -1,
                self.num_attention_heads_per_partition
                // self.num_multi_query_groups_per_partition,
                -1,
            )
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2]
                + (
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                )
            )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(
            query_layer, key_layer, value_layer, attention_mask
        )

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache


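# Multi-query attention in numbers (editorial sketch; the figures assume the
# published chatglm3-6b config: 32 attention heads, 2 KV groups, head dim 128,
# hidden size 4096). The fused projection then emits
#
#   qkv_hidden_size = 32 * 128 + 2 * (2 * 128) = 4608
#
# i.e. full-width queries plus one narrow key and value per group. Before
# CoreAttention, each KV group is expand()-ed across 32 / 2 = 16 query heads,
# so all heads in a group attend with the same keys and values instead of
# storing a separate K/V per head in the cache.

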
def _config_to_kwargs(args):
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    return common_kwargs


class MLP(torch.nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config: ChatGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config),
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config),
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


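# SwiGLU written out (editorial restatement of the inner `swiglu` above):
#
#   SwiGLU(x) = SiLU(x @ W_gate) * (x @ W_up)
#
# A single Linear of width 2 * ffn_hidden_size computes both halves at once
# and torch.chunk splits them, which is why dense_h_to_4h doubles its output
# width relative to a plain GELU/ReLU MLP.

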
class GLMBlock(torch.nn.Module):
    """A single transformer layer.
    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(GLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )

        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            device=device,
            dtype=config.torch_dtype,
        )

        # MLP
        self.mlp = MLP(config, device=device)

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_cache=None,
        use_cache=True,
    ):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            rotary_pos_emb,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = torch.nn.functional.dropout(
            attention_output, p=self.hidden_dropout, training=self.training
        )
        layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = torch.nn.functional.dropout(
            mlp_output, p=self.hidden_dropout, training=self.training
        )
        output = residual + output

        return output, kv_cache


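# The default (pre-LayerNorm) data flow of GLMBlock.forward, in equation form
# (editorial summary; `apply_residual_connection_post_layernorm` instead takes
# the residual from the normalized activations):
#
#   h1 = x  + Dropout(SelfAttention(Norm(x)))
#   h2 = h1 + Dropout(MLP(Norm(h1)))

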
class GLMTransformer(torch.nn.Module):
    """Transformer class."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(GLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        def build_layer(layer_number):
            return GLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1) for i in range(self.num_layers)]
        )

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(
                config.hidden_size,
                eps=config.layernorm_epsilon,
                device=device,
                dtype=config.torch_dtype,
            )

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
        self,
        hidden_states,
        attention_mask,
        rotary_pos_emb,
        kv_caches=None,
        use_cache: Optional[bool] = True,
        output_hidden_states: Optional[bool] = False,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        presents = () if use_cache else None
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_caches[index],
                    use_cache,
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    rotary_pos_emb,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache,
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
                presents = presents + (kv_cache,)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states, presents, all_hidden_states, all_self_attentions


class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, past_key_values, padding_mask=None):
        batch_size, seq_length = input_ids.shape
        full_attention_mask = torch.ones(
            batch_size, seq_length, seq_length, device=input_ids.device
        )
        full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        if past_length:
            full_attention_mask = torch.cat(
                (
                    torch.ones(
                        batch_size, seq_length, past_length, device=input_ids.device
                    ),
                    full_attention_mask,
                ),
                dim=-1,
            )
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask

    def get_position_ids(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        position_ids = (
            torch.arange(seq_length, dtype=torch.long, device=device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, GLMTransformer):
            module.gradient_checkpointing = value


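# get_masks in miniature (editorial sketch, not part of the original file).
# For a batch of one 3-token prompt with no cache and no padding, the returned
# mask has shape [b, 1, sq, sq] with True marking positions a query may NOT
# attend to:
#
#   [[[[False,  True,  True],
#      [False, False,  True],
#      [False, False, False]]]]
#
# With a KV cache of length p, a fully visible [sq, p] block of False is
# prepended along the key axis for the cached positions.

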
class Embedding(torch.nn.Module):
    """Language model embeddings."""

    def __init__(self, config: ChatGLMConfig, device=None):
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device,
        )
        self.fp32_residual_connection = config.fp32_residual_connection

    def forward(self, input_ids):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
        embeddings = embeddings.transpose(0, 1).contiguous()
        # If the fp32 residual connection flag is set, convert to float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings


class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads
            if config.kv_channels is None
            else config.kv_channels
        )

        self.rotary_pos_emb = RotaryEmbedding(
            rotary_dim // 2,
            original_impl=config.original_rope,
            device=device,
            dtype=config.torch_dtype,
        )
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        self.output_layer = init_method(
            nn.Linear,
            config.hidden_size,
            config.padded_vocab_size,
            bias=False,
            dtype=config.torch_dtype,
            **init_kwargs,
        )
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection
        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = (
            self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        )
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.multi_query_group_num,
            self.kv_channels,
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        full_attention_mask: Optional[torch.BoolTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if self.pre_seq_len is not None:
            if past_key_values is None:
                past_key_values = self.get_prompt(
                    batch_size=batch_size,
                    device=input_ids.device,
                    dtype=inputs_embeds.dtype,
                )
            if attention_mask is not None:
                attention_mask = torch.cat(
                    [
                        attention_mask.new_ones((batch_size, self.pre_seq_len)),
                        attention_mask,
                    ],
                    dim=-1,
                )

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (
                past_key_values and seq_length != 1
            ):
                full_attention_mask = self.get_masks(
                    input_ids, past_key_values, padding_mask=attention_mask
                )

        # Rotary positional embeddings
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds,
            full_attention_mask,
            rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
        )

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    presents,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def quantize(self, weight_bit_width: int):
        from .quantization import quantize

        quantize(self.encoder, weight_bit_width)
        return self


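# Typical entry point (editorial sketch, adapted from the ChatGLM usage docs;
# not part of the original file). The remote-code classes in this module are
# resolved through trust_remote_code=True:
#
#   from transformers import AutoTokenizer, AutoModel
#   tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
#   model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).half()
#   model = model.eval()

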
class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=-1,
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            )

        model_kwargs["is_first_forward"] = False
        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        is_first_forward: bool = True,
        **kwargs,
    ) -> dict:
        # only last token for input_ids if past is not None
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)
        if not is_first_forward:
            if past_key_values is not None:
                position_ids = position_ids[..., -1:]
                input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache,
        }

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, output, history):
        content = ""
        history = deepcopy(history)
        for response in output.split("<|assistant|>"):
            metadata, content = response.split("\n", maxsplit=1)
            if not metadata.strip():
                content = content.strip()
                history.append(
                    {"role": "assistant", "metadata": metadata, "content": content}
                )
                content = content.replace("[[训练时间]]", "2023年")
            else:
                history.append(
                    {"role": "assistant", "metadata": metadata, "content": content}
                )
                if history[0]["role"] == "system" and "tools" in history[0]:
                    content = "\n".join(content.split("\n")[1:-1])

                    def tool_call(**kwargs):
                        return kwargs

                    # `content` is expected to be a `tool_call(...)` expression
                    # emitted by the model; eval resolves it against the local
                    # tool_call helper above, yielding the kwargs dict.
                    parameters = eval(content)
                    content = {"name": metadata.strip(), "parameters": parameters}
                else:
                    content = {"name": metadata.strip(), "content": content}
        return content, history

    @torch.inference_mode()
    def chat(
        self,
        tokenizer,
        query: str,
        history: List[Dict] = None,
        role: str = "user",
        max_length: int = 8192,
        num_beams=1,
        do_sample=True,
        top_p=0.8,
        temperature=0.8,
        logits_processor=None,
        **kwargs,
    ):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
        inputs = inputs.to(self.device)
        eos_token_id = [
            tokenizer.eos_token_id,
            tokenizer.get_command("<|user|>"),
            tokenizer.get_command("<|observation|>"),
        ]
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
        response = tokenizer.decode(outputs)
        history.append({"role": role, "content": query})
        response, history = self.process_response(response, history)
        return response, history

    @torch.inference_mode()
    def stream_chat(
        self,
        tokenizer,
        query: str,
        history: List[Dict] = None,
        role: str = "user",
        past_key_values=None,
        max_length: int = 8192,
        do_sample=True,
        top_p=0.8,
        temperature=0.8,
        logits_processor=None,
        return_past_key_values=False,
        **kwargs,
    ):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        eos_token_id = [
            tokenizer.eos_token_id,
            tokenizer.get_command("<|user|>"),
            tokenizer.get_command("<|observation|>"),
        ]
        gen_kwargs = {
            "max_length": max_length,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }
        if past_key_values is None:
            inputs = tokenizer.build_chat_input(query, history=history, role=role)
        else:
            inputs = tokenizer.build_chat_input(query, role=role)
        inputs = inputs.to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[0]
            if self.transformer.pre_seq_len is not None:
                past_length -= self.transformer.pre_seq_len
            inputs.position_ids += past_length
            attention_mask = inputs.attention_mask
            attention_mask = torch.cat(
                (attention_mask.new_ones(1, past_length), attention_mask), dim=1
            )
            inputs["attention_mask"] = attention_mask
        history.append({"role": role, "content": query})
        for outputs in self.stream_generate(
            **inputs,
            past_key_values=past_key_values,
            eos_token_id=eos_token_id,
            return_past_key_values=return_past_key_values,
            **gen_kwargs,
        ):
            if return_past_key_values:
                outputs, past_key_values = outputs
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) : -1]
            response = tokenizer.decode(outputs)
            if response and response[-1] != "�":
                response, new_history = self.process_response(response, history)
                if return_past_key_values:
                    yield response, new_history, past_key_values
                else:
                    yield response, new_history

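    # Usage sketch for the two inference helpers above (editorial addition,
    # adapted from the ChatGLM usage docs; assumes a loaded `model` and
    # `tokenizer`):
    #
    #   response, history = model.chat(tokenizer, "Hello", history=[])
    #   for partial, history in model.stream_chat(tokenizer, "Hello", history=[]):
    #       print(partial)  # an incrementally longer decoded response
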
    @torch.inference_mode()
    def stream_generate(
        self,
        input_ids,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], List[int]]
        ] = None,
        return_past_key_values=False,
        **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache
        bos_token_id, eos_token_id = (
            generation_config.bos_token_id,
            generation_config.eos_token_id,
        )

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        assert eos_token_id is not None
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device)

        has_default_max_length = (
            kwargs.get("max_length") is None
            and generation_config.max_length is not None
        )
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length
            )
            if not has_default_max_length:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = (
                "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            )
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1)
                .ne(eos_token_id_tensor.unsqueeze(1))
                .prod(dim=0)
            )
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break

    def quantize(self, bits: int, empty_init=False, device=None, **kwargs):
        if bits == 0:
            return

        from .quantization import quantize

        if self.quantized:
            logger.info("Already quantized.")
            return self

        self.quantized = True

        self.config.quantization_bit = bits

        self.transformer.encoder = quantize(
            self.transformer.encoder,
            bits,
            empty_init=empty_init,
            device=device,
            **kwargs,
        )
        return self


class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)

        self.classifier_head = nn.Linear(
            config.hidden_size, config.num_labels, bias=True, dtype=torch.half
        )
        if config.classifier_dropout is not None:
            self.dropout = nn.Dropout(config.classifier_dropout)
        else:
            self.dropout = None
        self.config = config

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        full_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        pooled_hidden_states = hidden_states[-1]
        if self.dropout is not None:
            pooled_hidden_states = self.dropout(pooled_hidden_states)
        logits = self.classifier_head(pooled_hidden_states)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze().float(), labels.squeeze())
                else:
                    loss = loss_fct(logits.float(), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels).float(), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits.float(), labels.view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )