1"""
2This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/modeling_chatglm.py
3"""
4
5""" PyTorch ChatGLM model. """
6
7import copy8import math9import os10import re11import sys12import warnings13from typing import Any, Callable, Dict, List, Optional, Tuple, Union14
15import torch16import torch.nn.functional as F17import torch.utils.checkpoint18from torch import nn19from torch.nn import CrossEntropyLoss, LayerNorm20from torch.nn.utils import skip_init21from transformers.generation.logits_process import LogitsProcessor22from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList23from transformers.modeling_outputs import (24BaseModelOutputWithPast,25BaseModelOutputWithPastAndCrossAttentions,26CausalLMOutputWithPast,27)
28from transformers.modeling_utils import PreTrainedModel29from transformers.utils import (30add_code_sample_docstrings,31add_start_docstrings,32add_start_docstrings_to_model_forward,33logging,34)
35
36from .configuration_chatglm import ChatGLMConfig37
38# flags required to enable jit fusion kernels
39
40if sys.platform != "darwin":41torch._C._jit_set_profiling_mode(False)42torch._C._jit_set_profiling_executor(False)43torch._C._jit_override_can_fuse_on_cpu(True)44torch._C._jit_override_can_fuse_on_gpu(True)45
46logger = logging.get_logger(__name__)47
48_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"49_CONFIG_FOR_DOC = "ChatGLM6BConfig"50
51CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [52"THUDM/chatglm-6b",53# See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm54]
55
56
57class InvalidScoreLogitsProcessor(LogitsProcessor):58def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:59if torch.isnan(scores).any() or torch.isinf(scores).any():60scores.zero_()61scores[..., 5] = 5e462return scores63
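
# Illustrative sketch (not in the upstream file): the processor above replaces a
# NaN/inf-contaminated score row with a one-hot spike on token id 5 so that sampling
# cannot crash. A minimal check:
def _demo_invalid_score_processor() -> torch.FloatTensor:
    scores = InvalidScoreLogitsProcessor()(torch.tensor([[1, 2, 3]]), torch.full((1, 10), float("nan")))
    assert scores.argmax(dim=-1).item() == 5  # all mass pushed onto token id 5
    return scores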

def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model


class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
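
# Illustrative sketch (not in the upstream file): with prefix_projection=False the
# encoder is one embedding table mapping each prefix slot to a flat vector holding a
# key and a value for every layer. A toy SimpleNamespace stands in for ChatGLMConfig:
def _demo_prefix_encoder_shapes() -> torch.Tensor:
    from types import SimpleNamespace

    config = SimpleNamespace(prefix_projection=False, pre_seq_len=4, num_layers=2, hidden_size=8)
    out = PrefixEncoder(config)(torch.arange(4).unsqueeze(0))  # prefix ids: [batch=1, 4]
    assert out.shape == (1, 4, 2 * config.num_layers * config.hidden_size)
    return out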

@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)
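
# Illustrative sketch (not in the upstream file): 0.7978845608028654 is sqrt(2 / pi),
# i.e. gelu_impl is the tanh approximation of GELU; it tracks the exact erf-based
# F.gelu to within roughly 1e-3 on moderate inputs:
def _demo_gelu_approximation() -> torch.Tensor:
    x = torch.linspace(-3.0, 3.0, steps=61)
    return (gelu(x) - F.gelu(x)).abs().max()  # small but nonzero approximation error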

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer("inv_freq", inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)
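
# Illustrative sketch (not in the upstream file): the cached tables have shape
# [seq_len, 1, dim], ready to broadcast over the flattened (batch * heads) axis.
# .float() converts the half-precision inv_freq buffer so the demo also runs on CPU:
def _demo_rotary_cache_shapes():
    rope = RotaryEmbedding(dim=64, precision=torch.float).float()
    cos, sin = rope(torch.zeros(8, 1, 64), seq_dim=0)
    assert cos.shape == sin.shape == (8, 1, 64)
    return cos, sin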

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions


@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
        position_id, sin.squeeze(1)
    ).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k
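
# Illustrative sketch (not in the upstream file): rotary embedding rotates each
# channel pair by a position-dependent angle, so per-head norms are preserved.
# Shapes follow the comment above: q, k are [sq, b, np, hn], position_id is [sq, b].
def _demo_rotary_preserves_norm():
    sq, b, np_, hn = 4, 2, 3, 8
    q, k = torch.randn(sq, b, np_, hn), torch.randn(sq, b, np_, hn)
    cos, sin = RotaryEmbedding(dim=hn, precision=torch.float).float()(q, seq_dim=0)
    position_id = torch.arange(sq).unsqueeze(1).expand(sq, b)
    q_rot, k_rot = apply_rotary_pos_emb_index(q, k, cos, sin, position_id)
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)
    return q_rot, k_rot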

def attention_fn(
    self,
    query_layer,
    key_layer,
    value_layer,
    attention_mask,
    hidden_size_per_partition,
    layer_id,
    layer_past=None,
    scaling_attention_score=True,
    use_cache=False,
):
    if layer_past is not None:
        past_key, past_value = layer_past[0], layer_past[1]
        key_layer = torch.cat((past_key, key_layer), dim=0)
        value_layer = torch.cat((past_value, value_layer), dim=0)

    # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
    seq_len, b, nh, hidden_size = key_layer.shape

    if use_cache:
        present = (key_layer, value_layer)
    else:
        present = None

    query_key_layer_scaling_coeff = float(layer_id + 1)
    if scaling_attention_score:
        query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)

    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================

    # [b, np, sq, sk]
    output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

    # [sq, b, np, hn] -> [sq, b * np, hn]
    query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
    # [sk, b, np, hn] -> [sk, b * np, hn]
    key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

    matmul_result = torch.zeros(
        1,
        1,
        1,
        dtype=query_layer.dtype,
        device=query_layer.device,
    )

    matmul_result = torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),  # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
        beta=0.0,
        alpha=1.0,
    )

    # change view to [b, np, sq, sk]
    attention_scores = matmul_result.view(*output_size)

    if self.scale_mask_softmax:
        self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
    else:
        if not (attention_mask == 0).all():
            # if auto-regressive, skip
            attention_scores.masked_fill_(attention_mask, -10000.0)
        dtype = attention_scores.dtype
        attention_scores = attention_scores.float()
        attention_scores = attention_scores * query_key_layer_scaling_coeff

        attention_probs = F.softmax(attention_scores, dim=-1)

        attention_probs = attention_probs.type(dtype)

    # =========================
    # Context layer. [sq, b, hp]
    # =========================

    # value_layer -> context layer.
    # [sk, b, np, hn] --> [b, np, sq, hn]

    # context layer shape: [b, np, sq, hn]
    output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

    # change view [sk, b * np, hn]
    value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

    # change view [b * np, sq, sk]
    attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

    # matmul: [b * np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

    # change view [b, np, sq, hn]
    context_layer = context_layer.view(*output_size)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, present, attention_probs)

    return outputs
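
# Illustrative sketch (not in the upstream file): a shape walkthrough of attention_fn
# with toy sizes. `self` only needs a `scale_mask_softmax` attribute, so a
# SimpleNamespace stands in for the SelfAttention module; the all-zero bool mask takes
# the no-masking branch above. Note the query is divided by sqrt(hn) * (layer_id + 1)
# in half precision and the scores are multiplied by (layer_id + 1) again in float32:
# the factors cancel mathematically, the detour just avoids fp16 overflow.
def _demo_attention_fn_shapes():
    from types import SimpleNamespace

    sq, b, np_, hn = 5, 2, 4, 8
    q, k, v = (torch.randn(sq, b, np_, hn) for _ in range(3))
    ctx, present, probs = attention_fn(
        self=SimpleNamespace(scale_mask_softmax=None),
        query_layer=q,
        key_layer=k,
        value_layer=v,
        attention_mask=torch.zeros(b, 1, sq, sq, dtype=torch.bool),
        hidden_size_per_partition=np_ * hn,
        layer_id=0,
        use_cache=True,
    )
    assert ctx.shape == (sq, b, np_ * hn)  # [sq, b, hp]
    assert probs.shape == (b * np_, sq, sq)  # flattened [b * np, sq, sk]
    return ctx, present, probs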

def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)


class SelfAttention(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        layer_id,
        hidden_size_per_attention_head=None,
        bias=True,
        params_dtype=torch.float,
        position_encoding_2d=True,
        empty_init=True,
    ):
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        super(SelfAttention, self).__init__()

        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.hidden_size_per_partition = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_heads_per_partition = num_attention_heads
        self.position_encoding_2d = position_encoding_2d
        self.rotary_emb = RotaryEmbedding(
            self.hidden_size // (self.num_attention_heads * 2)
            if position_encoding_2d
            else self.hidden_size // self.num_attention_heads,
            base=10000,
            precision=torch.half,
            learnable=False,
        )

        self.scale_mask_softmax = None

        if hidden_size_per_attention_head is None:
            self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        else:
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head

        # Strided linear layer.
        self.query_key_value = init_method(
            torch.nn.Linear,
            hidden_size,
            3 * self.inner_hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

        self.dense = init_method(
            torch.nn.Linear,
            self.inner_hidden_size,
            hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

    @staticmethod
    def attention_mask_func(attention_scores, attention_mask):
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
        """Split a tensor along its last dimension.
        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        last_dim_size = tensor.size()[last_dim] // num_partitions
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)

        return tensor_list

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids,
        attention_mask: torch.Tensor,
        layer_id,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """

        # [seq_len, batch, 3 * hidden_size]
        mixed_raw_layer = self.query_key_value(hidden_states)

        # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
        new_tensor_shape = mixed_raw_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)

        # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
        (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)

        if self.position_encoding_2d:
            q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
            k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
            cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
            position_ids, block_position_ids = (
                position_ids[:, 0, :].transpose(0, 1).contiguous(),
                position_ids[:, 1, :].transpose(0, 1).contiguous(),
            )
            q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
            q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
            query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
            key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
        else:
            position_ids = position_ids.transpose(0, 1)
            cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
            # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
            query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)

        # [seq_len, batch, hidden_size]
        context_layer, present, attention_probs = attention_fn(
            self=self,
            query_layer=query_layer,
            key_layer=key_layer,
            value_layer=value_layer,
            attention_mask=attention_mask,
            hidden_size_per_partition=self.hidden_size_per_partition,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache,
        )

        output = self.dense(context_layer)

        outputs = (output, present)

        if output_attentions:
            outputs += (attention_probs,)

        return outputs  # output, present, attention_probs
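
# Illustrative usage sketch (not in the upstream file): driving a float32 SelfAttention
# instance end to end. position_ids carries two channels per token (absolute position
# and intra-block position), which is the 2D rotary scheme handled in forward above;
# .float() converts the half-precision rotary buffer so the sketch runs on CPU.
def _demo_self_attention_forward():
    sq, b, hidden, heads = 6, 2, 32, 4
    attn = SelfAttention(hidden, heads, layer_id=0, params_dtype=torch.float, empty_init=False).float()
    hidden_states = torch.randn(sq, b, hidden)
    position_ids = torch.stack(
        (torch.arange(sq).expand(b, sq), torch.zeros(b, sq, dtype=torch.long)), dim=1
    )  # [b, 2, sq]
    attention_mask = torch.zeros(b, 1, sq, sq, dtype=torch.bool)
    output, present = attn(hidden_states, position_ids, attention_mask, layer_id=torch.tensor(0), use_cache=True)
    assert output.shape == (sq, b, hidden)
    return output, present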

class GEGLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_fn = F.gelu

    def forward(self, x):
        # dim=-1 breaks in jit for pt<1.10
        x1, x2 = x.chunk(2, dim=(x.ndim - 1))
        return x1 * self.activation_fn(x2)
522class GLU(torch.nn.Module):523def __init__(524self,525hidden_size,526inner_hidden_size=None,527layer_id=None,528bias=True,529activation_func=gelu,530params_dtype=torch.float,531empty_init=True,532):533super(GLU, self).__init__()534if empty_init:535init_method = skip_init536else:537init_method = default_init538self.layer_id = layer_id539self.activation_func = activation_func540
541# Project to 4h.542self.hidden_size = hidden_size543if inner_hidden_size is None:544inner_hidden_size = 4 * hidden_size545self.inner_hidden_size = inner_hidden_size546self.dense_h_to_4h = init_method(547torch.nn.Linear,548self.hidden_size,549self.inner_hidden_size,550bias=bias,551dtype=params_dtype,552)553# Project back to h.554self.dense_4h_to_h = init_method(555torch.nn.Linear,556self.inner_hidden_size,557self.hidden_size,558bias=bias,559dtype=params_dtype,560)561
562def forward(self, hidden_states):563"""564hidden_states: [seq_len, batch, hidden_size]
565"""
566
567# [seq_len, batch, inner_hidden_size]568intermediate_parallel = self.dense_h_to_4h(hidden_states)569
570intermediate_parallel = self.activation_func(intermediate_parallel)571
572output = self.dense_4h_to_h(intermediate_parallel)573
574return output575

class GLMBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        layernorm_epsilon,
        layer_id,
        inner_hidden_size=None,
        hidden_size_per_attention_head=None,
        layernorm=LayerNorm,
        use_bias=True,
        params_dtype=torch.float,
        num_layers=28,
        position_encoding_2d=True,
        empty_init=True,
    ):
        super(GLMBlock, self).__init__()
        # Set output layer initialization if not provided.

        self.layer_id = layer_id

        # Layernorm on the input data.
        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        self.position_encoding_2d = position_encoding_2d

        # Self attention.
        self.attention = SelfAttention(
            hidden_size,
            num_attention_heads,
            layer_id,
            hidden_size_per_attention_head=hidden_size_per_attention_head,
            bias=use_bias,
            params_dtype=params_dtype,
            position_encoding_2d=self.position_encoding_2d,
            empty_init=empty_init,
        )

        # Layernorm after the self attention.
        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        self.num_layers = num_layers

        # GLU
        self.mlp = GLU(
            hidden_size,
            inner_hidden_size=inner_hidden_size,
            bias=use_bias,
            layer_id=layer_id,
            params_dtype=params_dtype,
            empty_init=empty_init,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids,
        attention_mask: torch.Tensor,
        layer_id,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """

        # Layer norm at the beginning of the transformer layer.
        # [seq_len, batch, hidden_size]
        attention_input = self.input_layernorm(hidden_states)

        # Self attention.
        attention_outputs = self.attention(
            attention_input,
            position_ids,
            attention_mask=attention_mask,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attention_outputs[0]

        outputs = attention_outputs[1:]

        # Residual connection.
        alpha = (2 * self.num_layers) ** 0.5
        hidden_states = attention_input * alpha + attention_output

        mlp_input = self.post_attention_layernorm(hidden_states)

        # MLP.
        mlp_output = self.mlp(mlp_input)

        # Second residual connection.
        output = mlp_input * alpha + mlp_output

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions
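
# Illustrative note (not in the upstream file): unlike the standard pre-LN residual
# x + f(LayerNorm(x)), GLMBlock multiplies the *normalized* input by
# alpha = sqrt(2 * num_layers) before adding the sublayer output, a DeepNorm-style
# rescaling that helps keep very deep fp16 stacks stable. The pattern in isolation:
def _demo_scaled_residual(x: torch.Tensor, sublayer_output: torch.Tensor, num_layers: int = 28) -> torch.Tensor:
    alpha = (2 * num_layers) ** 0.5
    return x * alpha + sublayer_output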

class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        attention_mask.tril_()
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        return attention_mask
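
    # Illustrative sketch (not in the upstream file): get_masks builds a prefix-LM
    # mask -- bidirectional over the prompt (tokens before bos) and causal afterwards,
    # with True marking positions to mask out. The same construction for one sequence:
    @staticmethod
    def _demo_prefix_causal_mask(seq_length: int = 5, context_length: int = 3) -> torch.Tensor:
        mask = torch.ones(seq_length, seq_length).tril_()
        mask[:, :context_length] = 1  # every position may attend to the whole prompt
        return (mask < 0.5).bool()  # True == masked, as in get_masks above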

    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
        batch_size, seq_length = input_ids.shape
        if use_gmasks is None:
            use_gmasks = [False] * batch_size
        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
        if self.position_encoding_2d:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                position_ids[i, context_length:] = mask_positions[i]
            block_position_ids = [
                torch.cat(
                    (
                        torch.zeros(context_length, dtype=torch.long, device=device),
                        torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
                    )
                )
                for context_length in context_lengths
            ]
            block_position_ids = torch.stack(block_position_ids, dim=0)
            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        else:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                if not use_gmasks[i]:
                    position_ids[i, context_length:] = mask_positions[i]

        return position_ids
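
    # Illustrative sketch (not in the upstream file): in the 2D scheme every generated
    # token reuses the absolute position of the mask token it expands, while a second
    # channel counts 1, 2, 3, ... within the generated block. For seq_length=6,
    # context_length=3, mask_position=1 this returns
    # [[0, 1, 2, 1, 1, 1], [0, 0, 0, 1, 2, 3]].
    @staticmethod
    def _demo_2d_position_ids(seq_length: int = 6, context_length: int = 3, mask_position: int = 1) -> torch.Tensor:
        absolute = torch.arange(seq_length)
        absolute[context_length:] = mask_position
        block = torch.cat(
            (torch.zeros(context_length, dtype=torch.long), torch.arange(seq_length - context_length) + 1)
        )
        return torch.stack((absolute, block))  # [2, seq_length]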

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ChatGLMModel):
            module.gradient_checkpointing = value

CHATGLM_6B_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CHATGLM_6B_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`ChatGLM6BTokenizer`].
            See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert *input_ids* indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.",
    CHATGLM_6B_START_DOCSTRING,
)
class ChatGLMModel(ChatGLMPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the
    `is_decoder` argument of the configuration set to `True`.
    To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder`
    argument and `add_cross_attention` set to `True`; an
    `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config: ChatGLMConfig, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        # recording parameters
        self.max_sequence_length = config.max_sequence_length
        self.hidden_size = config.hidden_size
        self.params_dtype = torch.half
        self.num_attention_heads = config.num_attention_heads
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_layers
        self.layernorm_epsilon = config.layernorm_epsilon
        self.inner_hidden_size = config.inner_hidden_size
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
        self.position_encoding_2d = config.position_encoding_2d
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection

        self.word_embeddings = init_method(
            torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
        )
        self.gradient_checkpointing = False

        def get_layer(layer_id):
            return GLMBlock(
                self.hidden_size,
                self.num_attention_heads,
                self.layernorm_epsilon,
                layer_id,
                inner_hidden_size=self.inner_hidden_size,
                hidden_size_per_attention_head=self.hidden_size_per_attention_head,
                layernorm=LayerNorm,
                use_bias=True,
                params_dtype=self.params_dtype,
                position_encoding_2d=self.position_encoding_2d,
                empty_init=empty_init,
            )

        self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)

        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

            # total_params = sum(p.numel() for p in self.parameters())
            # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
            # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.num_attention_heads,
            self.hidden_size // self.num_attention_heads,
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        # past_key_values = [(v[0], v[1]) for v in past_key_values]
        return past_key_values
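
    # Illustrative sketch (not in the upstream file): a standalone walkthrough of the
    # reshape performed by get_prompt, with toy sizes instead of the real config. Each
    # layer ends up with one [2, pre_seq_len, batch, num_heads, head_dim] tensor
    # holding its prefix keys and values.
    @staticmethod
    def _demo_prefix_kv_shapes(batch_size=2, pre_seq_len=4, num_layers=2, num_heads=2, head_dim=3):
        flat = torch.randn(batch_size, pre_seq_len, num_layers * 2 * num_heads * head_dim)
        kv = flat.view(batch_size, pre_seq_len, num_layers * 2, num_heads, head_dim)
        kv = kv.permute([2, 1, 0, 3, 4]).split(2)  # one [2, sq, b, nh, hd] tensor per layer
        assert len(kv) == num_layers
        assert kv[0].shape == (2, pre_seq_len, batch_size, num_heads, head_dim)
        return kv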

    @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if past_key_values is None:
            if self.pre_seq_len is not None:
                past_key_values = self.get_prompt(
                    batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
                )
            else:
                past_key_values = tuple([None] * len(self.layers))

            if attention_mask is None:
                attention_mask = self.get_masks(input_ids, device=input_ids.device)

            if position_ids is None:
                MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
                seqs = input_ids.tolist()

                mask_positions, use_gmasks = [], []
                for seq in seqs:
                    mask_token = gMASK if gMASK in seq else MASK
                    use_gmask = mask_token == gMASK
                    mask_positions.append(seq.index(mask_token))
                    use_gmasks.append(use_gmask)

                position_ids = self.get_position_ids(
                    input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
                )

        if self.pre_seq_len is not None and attention_mask is not None:
            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
                attention_mask.device
            )
            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

        # [seq_len, batch, hidden_size]
        hidden_states = inputs_embeds.transpose(0, 1)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        if attention_mask is None:
            attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_past = past_key_values[i]

            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    position_ids,
                    attention_mask,
                    torch.tensor(i),
                    layer_past,
                    use_cache,
                    output_attentions,
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    layer_id=torch.tensor(i),
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_ret[0]

            if use_cache:
                presents = presents + (layer_ret[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)

        # Final layer norm.
        hidden_states = self.final_layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init

        # self.hidden_size = config.hidden_size
        # self.params_dtype = torch.half
        # self.vocab_size = config.vocab_size
        self.max_sequence_length = config.max_sequence_length

        self.position_encoding_2d = config.position_encoding_2d

        self.transformer = ChatGLMModel(config, empty_init=empty_init)

        self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)

        self.config = config

        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
                )
                new_attention_mask = attention_mask[:, :, -1:].clone()
                new_attention_mask[..., -1] = False
                model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id[:, 1, :] += 1
            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)

        return model_kwargs
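
    # Illustrative sketch (not in the upstream file): during incremental decoding only
    # the intra-block channel of the 2D position ids advances; the absolute channel
    # stays pinned at the mask position, mirroring the update above.
    @staticmethod
    def _demo_position_id_step() -> torch.Tensor:
        position_ids = torch.tensor([[[4], [1]]])  # [batch=1, 2, cur_len=1]
        new_position_id = position_ids[..., -1:].clone()
        new_position_id[:, 1, :] += 1  # only the block counter increments
        return torch.cat([position_ids, new_position_id], dim=-1)  # [[[4, 4], [1, 2]]]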

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        batch_size, seq_length = input_ids.shape
        MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
        seqs = input_ids.tolist()
        mask_positions, use_gmasks = [], []
        for seq in seqs:
            mask_token = gMASK if gMASK in seq else MASK
            use_gmask = mask_token == gMASK
            mask_positions.append(seq.index(mask_token))
            use_gmasks.append(use_gmask)

        # only last token for input_ids if past is not None
        if past is not None or past_key_values is not None:
            last_token = input_ids[:, -1].unsqueeze(-1)
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                attention_mask = attention_mask[:, :, -1:]
            else:
                attention_mask = None
            if position_ids is not None:
                position_ids = position_ids[..., -1:]
            else:
                context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
                if self.position_encoding_2d:
                    position_ids = torch.tensor(
                        [
                            [mask_position, seq_length - context_length]
                            for mask_position, context_length in zip(mask_positions, context_lengths)
                        ],
                        dtype=torch.long,
                        device=input_ids.device,
                    ).unsqueeze(-1)
                else:
                    position_ids = torch.tensor(
                        [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
                    ).unsqueeze(-1)

            if past is None:
                past = past_key_values
            return {
                "input_ids": last_token,
                "past_key_values": past,
                "position_ids": position_ids,
                "attention_mask": attention_mask,
            }
        else:
            if attention_mask is not None and attention_mask.dtype != torch.bool:
                logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
                attention_mask = None
            if attention_mask is None:
                attention_mask = self.get_masks(input_ids, device=input_ids.device)
            if position_ids is None:
                position_ids = self.get_position_ids(
                    input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
                )

            return {
                "input_ids": input_ids,
                "past_key_values": past,
                "position_ids": position_ids,
                "attention_mask": attention_mask,
            }

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )
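
    # Illustrative sketch (not in the upstream file): the cache is stored
    # sequence-first ([sk, beams, nh, hd]), so beams are gathered along dim 1.
    @staticmethod
    def _demo_reorder_cache():
        layer_past = (torch.randn(5, 4, 2, 3), torch.randn(5, 4, 2, 3))  # [sk, beams, nh, hd]
        beam_idx = torch.tensor([1, 1, 0, 3])
        ((k, v),) = ChatGLMForConditionalGeneration._reorder_cache((layer_past,), beam_idx)
        assert torch.equal(k[:, 0], layer_past[0][:, 1])  # new beam 0 copied from old beam 1
        return k, v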

    def process_response(self, response):
        response = response.strip()
        response = response.replace("[[训练时间]]", "2023年")
        punkts = [
            [",", ","],
            ["!", "!"],
            [":", ":"],
            [";", ";"],
            ["\?", "?"],
        ]
        for item in punkts:
            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
        return response
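
    # Illustrative sketch (not in the upstream file): process_response swaps ASCII
    # punctuation for its full-width form only where it touches a CJK character
    # (the \u4e00-\u9fff range); Latin text such as "a,b" is left untouched.
    def _demo_process_response(self) -> None:
        assert self.process_response("你好,世界!") == "你好,世界!"
        assert self.process_response("a,b") == "a,b"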

    @torch.no_grad()
    def chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = None,
        max_length: int = 2048,
        num_beams=1,
        do_sample=True,
        top_p=0.7,
        temperature=0.95,
        logits_processor=None,
        **kwargs,
    ):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        outputs = self.generate(**inputs, **gen_kwargs)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history
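
    # Illustrative usage sketch (not in the upstream file): typical two-turn use of
    # chat(). Downloading the checkpoint named at the top of this file requires
    # network access, and .half().cuda() assumes a GPU.
    @staticmethod
    def _demo_chat_usage():
        from transformers import AutoModel, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
        model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
        response, history = model.chat(tokenizer, "你好")
        response, history = model.chat(tokenizer, "谢谢", history=history)
        return response, history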

    @torch.no_grad()
    def stream_chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = None,
        max_length: int = 2048,
        do_sample=True,
        top_p=0.7,
        temperature=0.95,
        logits_processor=None,
        **kwargs,
    ):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        for outputs in self.stream_generate(**inputs, **gen_kwargs):
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
            response = tokenizer.decode(outputs)
            response = self.process_response(response)
            new_history = history + [(query, response)]
            yield response, new_history

    @torch.no_grad()
    def stream_generate(
        self,
        input_ids,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break
            yield input_ids
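
    # Illustrative usage sketch (not in the upstream file): stream_generate yields the
    # running input_ids after every decoding step, so partial replies can be decoded
    # incrementally (this is what stream_chat builds on).
    def _demo_stream_usage(self, tokenizer, prompt: str = "你好"):
        inputs = tokenizer([prompt], return_tensors="pt").to(self.device)
        for ids in self.stream_generate(**inputs, max_new_tokens=32):
            yield tokenizer.decode(ids.tolist()[0][len(inputs["input_ids"][0]) :])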

    def quantize(self, bits: int, empty_init=False, **kwargs):
        if bits == 0:
            return

        from .quantization import quantize

        if self.quantized:
            logger.info("Already quantized.")
            return self

        self.quantized = True

        self.config.quantization_bit = bits

        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
        return self