"""
This code is copied from https://huggingface.co/THUDM/chatglm-6b/resolve/main/modeling_chatglm.py
"""

""" PyTorch ChatGLM model. """

import copy
import math
import os
import re
import sys
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm
from torch.nn.utils import skip_init
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import GenerationConfig, LogitsProcessorList, ModelOutput, StoppingCriteriaList
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)

from .configuration_chatglm import ChatGLMConfig

# flags required to enable jit fusion kernels

if sys.platform != "darwin":
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
_CONFIG_FOR_DOC = "ChatGLM6BConfig"

CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "THUDM/chatglm-6b",
    # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
]


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] == "kernel" or scope_names[0] == "gamma":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "output_weights":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if m_name[-11:] == "_embeddings":
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)
        try:
            assert (
                pointer.shape == array.shape
            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)
    return model


class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2),
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values

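
# --- Illustrative sketch, not part of the original THUDM checkpoint code ----
# PrefixEncoder maps prefix token indices of shape (batch, pre_seq_len) to a
# tensor of shape (batch, pre_seq_len, 2 * num_layers * hidden_size), i.e. one
# key slice and one value slice per transformer layer for P-Tuning v2. The
# helper below is a minimal shape check under an assumed toy config; the
# attribute names mirror ChatGLMConfig, but the values are made up and the
# function is never called by the model itself.
def _prefix_encoder_shape_sketch():
    from types import SimpleNamespace

    toy_config = SimpleNamespace(prefix_projection=False, pre_seq_len=4, num_layers=2, hidden_size=8)
    encoder = PrefixEncoder(toy_config)
    prefix = torch.arange(toy_config.pre_seq_len).unsqueeze(0)  # (1, 4)
    out = encoder(prefix)
    # Expected: (1, 4, 2 * 2 * 8) == (1, 4, 32)
    assert out.shape == (1, toy_config.pre_seq_len, 2 * toy_config.num_layers * toy_config.hidden_size)
    return out.shape
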

@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def gelu(x):
    return gelu_impl(x)


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer("inv_freq", inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions


@torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
        position_id, sin.squeeze(1)
    ).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k

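
# --- Illustrative sketch, not part of the original THUDM checkpoint code ----
# apply_rotary_pos_emb_index gathers the cos/sin rows belonging to each
# position with F.embedding and rotates q/k pairwise in the last dimension.
# The helper below rebuilds cos/sin the same way RotaryEmbedding.forward does
# (in float32) for made-up sizes sq=3, b=1, np=2, hn=4 and only checks that
# the rotated tensors keep their shapes; it is never called by the model.
def _rotary_embedding_shape_sketch():
    sq, b, np_, hn = 3, 1, 2, 4
    inv_freq = 1.0 / (10000 ** (torch.arange(0, hn, 2).float() / hn))
    t = torch.arange(sq, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]  # [sq, 1, hn]
    q = torch.randn(sq, b, np_, hn)
    k = torch.randn(sq, b, np_, hn)
    position_id = torch.arange(sq).unsqueeze(1).expand(sq, b).contiguous()  # [sq, b]
    q_rot, k_rot = apply_rotary_pos_emb_index(q, k, cos, sin, position_id)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    return q_rot.shape
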

def attention_fn(
    self,
    query_layer,
    key_layer,
    value_layer,
    attention_mask,
    hidden_size_per_partition,
    layer_id,
    layer_past=None,
    scaling_attention_score=True,
    use_cache=False,
):
    if layer_past is not None:
        past_key, past_value = layer_past[0], layer_past[1]
        key_layer = torch.cat((past_key, key_layer), dim=0)
        value_layer = torch.cat((past_value, value_layer), dim=0)

    # seqlen, batch, num_attention_heads, hidden_size_per_attention_head
    seq_len, b, nh, hidden_size = key_layer.shape

    if use_cache:
        present = (key_layer, value_layer)
    else:
        present = None

    query_key_layer_scaling_coeff = float(layer_id + 1)
    if scaling_attention_score:
        query_layer = query_layer / (math.sqrt(hidden_size) * query_key_layer_scaling_coeff)

    # ===================================
    # Raw attention scores. [b, np, s, s]
    # ===================================

    # [b, np, sq, sk]
    output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

    # [sq, b, np, hn] -> [sq, b * np, hn]
    query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
    # [sk, b, np, hn] -> [sk, b * np, hn]
    key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

    matmul_result = torch.zeros(
        1,
        1,
        1,
        dtype=query_layer.dtype,
        device=query_layer.device,
    )

    matmul_result = torch.baddbmm(
        matmul_result,
        query_layer.transpose(0, 1),  # [b * np, sq, hn]
        key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
        beta=0.0,
        alpha=1.0,
    )

    # change view to [b, np, sq, sk]
    attention_scores = matmul_result.view(*output_size)

    if self.scale_mask_softmax:
        self.scale_mask_softmax.scale = query_key_layer_scaling_coeff
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask.contiguous())
    else:
        if not (attention_mask == 0).all():
            # if auto-regressive, skip
            attention_scores.masked_fill_(attention_mask, -10000.0)
        dtype = attention_scores.dtype
        attention_scores = attention_scores.float()
        attention_scores = attention_scores * query_key_layer_scaling_coeff

        attention_probs = F.softmax(attention_scores, dim=-1)

        attention_probs = attention_probs.type(dtype)

    # =========================
    # Context layer. [sq, b, hp]
    # =========================

    # value_layer -> context layer.
    # [sk, b, np, hn] --> [b, np, sq, hn]

    # context layer shape: [b, np, sq, hn]
    output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

    # change view [sk, b * np, hn]
    value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

    # change view [b * np, sq, sk]
    attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

    # matmul: [b * np, sq, hn]
    context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

    # change view [b, np, sq, hn]
    context_layer = context_layer.view(*output_size)

    # [b, np, sq, hn] --> [sq, b, np, hn]
    context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

    # [sq, b, np, hn] --> [sq, b, hp]
    new_context_layer_shape = context_layer.size()[:-2] + (hidden_size_per_partition,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, present, attention_probs)

    return outputs

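
# --- Illustrative note, not part of the original THUDM checkpoint code ------
# attention_fn divides the query by sqrt(hn) * (layer_id + 1) up front and, in
# the unfused softmax branch, multiplies the scores by the same
# query_key_layer_scaling_coeff again, so the net scaling reaching the softmax
# is the usual 1 / sqrt(hn). The tiny check below (with made-up sizes) shows
# that the two scalings cancel; it is never called by the model.
def _attention_scaling_sketch(layer_id: int = 3, hn: int = 8):
    coeff = float(layer_id + 1)
    q = torch.randn(5, hn)
    k = torch.randn(5, hn)
    scaled_twice = (q / (math.sqrt(hn) * coeff)) @ k.t() * coeff
    scaled_once = (q @ k.t()) / math.sqrt(hn)
    assert torch.allclose(scaled_twice, scaled_once, atol=1e-5)
    return F.softmax(scaled_once, dim=-1)
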

def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)


class SelfAttention(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        layer_id,
        hidden_size_per_attention_head=None,
        bias=True,
        params_dtype=torch.float,
        position_encoding_2d=True,
        empty_init=True,
    ):
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        super(SelfAttention, self).__init__()

        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.hidden_size_per_partition = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_attention_heads_per_partition = num_attention_heads
        self.position_encoding_2d = position_encoding_2d
        self.rotary_emb = RotaryEmbedding(
            self.hidden_size // (self.num_attention_heads * 2)
            if position_encoding_2d
            else self.hidden_size // self.num_attention_heads,
            base=10000,
            precision=torch.half,
            learnable=False,
        )

        self.scale_mask_softmax = None

        if hidden_size_per_attention_head is None:
            self.hidden_size_per_attention_head = hidden_size // num_attention_heads
        else:
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

        self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head

        # Strided linear layer.
        self.query_key_value = init_method(
            torch.nn.Linear,
            hidden_size,
            3 * self.inner_hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

        self.dense = init_method(
            torch.nn.Linear,
            self.inner_hidden_size,
            hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

    @staticmethod
    def attention_mask_func(attention_scores, attention_mask):
        attention_scores.masked_fill_(attention_mask, -10000.0)
        return attention_scores

    def split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_chunks=False):
        """Split a tensor along its last dimension.
        Arguments:
            tensor: input tensor.
            num_partitions: number of partitions to split the tensor
            contiguous_split_chunks: If True, make each chunk contiguous
                                    in memory.
        """
        # Get the size and dimension.
        last_dim = tensor.dim() - 1
        last_dim_size = tensor.size()[last_dim] // num_partitions
        # Split.
        tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
        # Note: torch.split does not create contiguous tensors by default.
        if contiguous_split_chunks:
            return tuple(chunk.contiguous() for chunk in tensor_list)

        return tensor_list

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids,
        attention_mask: torch.Tensor,
        layer_id,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """

        # [seq_len, batch, 3 * hidden_size]
        mixed_raw_layer = self.query_key_value(hidden_states)

        # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3 * hidden_size_per_attention_head]
        new_tensor_shape = mixed_raw_layer.size()[:-1] + (
            self.num_attention_heads_per_partition,
            3 * self.hidden_size_per_attention_head,
        )
        mixed_raw_layer = mixed_raw_layer.view(*new_tensor_shape)

        # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
        (query_layer, key_layer, value_layer) = self.split_tensor_along_last_dim(mixed_raw_layer, 3)

        if self.position_encoding_2d:
            q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
            k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
            cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
            position_ids, block_position_ids = (
                position_ids[:, 0, :].transpose(0, 1).contiguous(),
                position_ids[:, 1, :].transpose(0, 1).contiguous(),
            )
            q1, k1 = apply_rotary_pos_emb_index(q1, k1, cos, sin, position_ids)
            q2, k2 = apply_rotary_pos_emb_index(q2, k2, cos, sin, block_position_ids)
            query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
            key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
        else:
            position_ids = position_ids.transpose(0, 1)
            cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
            # [seq_len, batch, num_attention_heads, hidden_size_per_attention_head]
            query_layer, key_layer = apply_rotary_pos_emb_index(query_layer, key_layer, cos, sin, position_ids)

        # [seq_len, batch, hidden_size]
        context_layer, present, attention_probs = attention_fn(
            self=self,
            query_layer=query_layer,
            key_layer=key_layer,
            value_layer=value_layer,
            attention_mask=attention_mask,
            hidden_size_per_partition=self.hidden_size_per_partition,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache,
        )

        output = self.dense(context_layer)

        outputs = (output, present)

        if output_attentions:
            outputs += (attention_probs,)

        return outputs  # output, present, attention_probs

class GEGLU(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.activation_fn = F.gelu

    def forward(self, x):
        # dim=-1 breaks in jit for pt<1.10
        x1, x2 = x.chunk(2, dim=(x.ndim - 1))
        return x1 * self.activation_fn(x2)


class GLU(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        inner_hidden_size=None,
        layer_id=None,
        bias=True,
        activation_func=gelu,
        params_dtype=torch.float,
        empty_init=True,
    ):
        super(GLU, self).__init__()
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        self.layer_id = layer_id
        self.activation_func = activation_func

        # Project to 4h.
        self.hidden_size = hidden_size
        if inner_hidden_size is None:
            inner_hidden_size = 4 * hidden_size
        self.inner_hidden_size = inner_hidden_size
        self.dense_h_to_4h = init_method(
            torch.nn.Linear,
            self.hidden_size,
            self.inner_hidden_size,
            bias=bias,
            dtype=params_dtype,
        )
        # Project back to h.
        self.dense_4h_to_h = init_method(
            torch.nn.Linear,
            self.inner_hidden_size,
            self.hidden_size,
            bias=bias,
            dtype=params_dtype,
        )

    def forward(self, hidden_states):
        """
        hidden_states: [seq_len, batch, hidden_size]
        """

        # [seq_len, batch, inner_hidden_size]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)

        intermediate_parallel = self.activation_func(intermediate_parallel)

        output = self.dense_4h_to_h(intermediate_parallel)

        return output

class GLMBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        layernorm_epsilon,
        layer_id,
        inner_hidden_size=None,
        hidden_size_per_attention_head=None,
        layernorm=LayerNorm,
        use_bias=True,
        params_dtype=torch.float,
        num_layers=28,
        position_encoding_2d=True,
        empty_init=True,
    ):
        super(GLMBlock, self).__init__()
        # Set output layer initialization if not provided.

        self.layer_id = layer_id

        # Layernorm on the input data.
        self.input_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        self.position_encoding_2d = position_encoding_2d

        # Self attention.
        self.attention = SelfAttention(
            hidden_size,
            num_attention_heads,
            layer_id,
            hidden_size_per_attention_head=hidden_size_per_attention_head,
            bias=use_bias,
            params_dtype=params_dtype,
            position_encoding_2d=self.position_encoding_2d,
            empty_init=empty_init,
        )

        # Layernorm after the self attention.
        self.post_attention_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

        self.num_layers = num_layers

        # GLU
        self.mlp = GLU(
            hidden_size,
            inner_hidden_size=inner_hidden_size,
            bias=use_bias,
            layer_id=layer_id,
            params_dtype=params_dtype,
            empty_init=empty_init,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids,
        attention_mask: torch.Tensor,
        layer_id,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """
        hidden_states: [seq_len, batch, hidden_size]
        attention_mask: [(1, 1), seq_len, seq_len]
        """

        # Layer norm at the beginning of the transformer layer.
        # [seq_len, batch, hidden_size]
        attention_input = self.input_layernorm(hidden_states)

        # Self attention.
        attention_outputs = self.attention(
            attention_input,
            position_ids,
            attention_mask=attention_mask,
            layer_id=layer_id,
            layer_past=layer_past,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attention_output = attention_outputs[0]

        outputs = attention_outputs[1:]

        # Residual connection.
        alpha = (2 * self.num_layers) ** 0.5
        hidden_states = attention_input * alpha + attention_output

        mlp_input = self.post_attention_layernorm(hidden_states)

        # MLP.
        mlp_output = self.mlp(mlp_input)

        # Second residual connection.
        output = mlp_input * alpha + mlp_output

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, present, attentions

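
# --- Illustrative note, not part of the original THUDM checkpoint code ------
# GLMBlock re-scales the branch *input* by alpha = sqrt(2 * num_layers) before
# adding the branch output, for both the attention and the MLP sub-blocks
# above (a DeepNorm-style residual). With the default num_layers=28 this gives
# alpha ~= 7.48.
def _glm_residual_alpha(num_layers: int = 28) -> float:
    return (2 * num_layers) ** 0.5  # ~7.4833 for ChatGLM-6B's 28 layers
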
class ChatGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = ChatGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["GLMBlock"]

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        return

    def get_masks(self, input_ids, device):
        batch_size, seq_length = input_ids.shape
        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
        attention_mask.tril_()
        for i, context_length in enumerate(context_lengths):
            attention_mask[i, :, :context_length] = 1
        attention_mask.unsqueeze_(1)
        attention_mask = (attention_mask < 0.5).bool()

        return attention_mask

    def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
        batch_size, seq_length = input_ids.shape
        if use_gmasks is None:
            use_gmasks = [False] * batch_size
        context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
        if self.position_encoding_2d:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                position_ids[i, context_length:] = mask_positions[i]
            block_position_ids = [
                torch.cat(
                    (
                        torch.zeros(context_length, dtype=torch.long, device=device),
                        torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1,
                    )
                )
                for context_length in context_lengths
            ]
            block_position_ids = torch.stack(block_position_ids, dim=0)
            position_ids = torch.stack((position_ids, block_position_ids), dim=1)
        else:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
            for i, context_length in enumerate(context_lengths):
                if not use_gmasks[i]:
                    position_ids[i, context_length:] = mask_positions[i]

        return position_ids

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ChatGLMModel):
            module.gradient_checkpointing = value

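
# --- Illustrative sketches, not part of the original THUDM checkpoint code --
# get_masks above builds a mask that is bidirectional over the prompt (every
# token before bos) and causal after it, with True marking positions that must
# NOT be attended to. The first helper reproduces that logic for one toy
# sequence with a made-up bos id; the second shows the two position channels
# produced by get_position_ids when position_encoding_2d is enabled. Neither
# helper is called by the model.
def _chatglm_mask_sketch():
    bos_token_id = 5  # assumed toy value, not the real ChatGLM-6B bos id
    input_ids = torch.tensor([[11, 12, 13, bos_token_id, 21, 22]])
    batch_size, seq_length = input_ids.shape
    context_length = input_ids[0].tolist().index(bos_token_id)
    mask = torch.ones(batch_size, seq_length, seq_length).tril_()
    mask[0, :, :context_length] = 1          # prompt tokens stay fully visible
    return (mask.unsqueeze(1) < 0.5).bool()  # (1, 1, 6, 6), True == masked


def _position_ids_2d_sketch():
    seq_length, context_length, mask_position = 6, 3, 1
    position_ids = torch.arange(seq_length, dtype=torch.long)
    position_ids[context_length:] = mask_position  # -> [0, 1, 2, 1, 1, 1]
    block_position_ids = torch.cat(
        (
            torch.zeros(context_length, dtype=torch.long),
            torch.arange(seq_length - context_length, dtype=torch.long) + 1,
        )
    )  # -> [0, 0, 0, 1, 2, 3]
    # get_position_ids stacks these per batch element into (batch, 2, seq_length).
    return torch.stack((position_ids, block_position_ids), dim=0)
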
CHATGLM_6B_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.

    Parameters:
        config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CHATGLM_6B_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`ChatGLM6BTokenizer`].
            See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range `[0, config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert *input_ids* indices into associated vectors
            than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ChatGLM-6B Model transformer outputting raw hidden-states without any specific head on top.",
    CHATGLM_6B_START_DOCSTRING,
)
class ChatGLMModel(ChatGLMPreTrainedModel):
    """

    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the
    `is_decoder` argument of the configuration set to `True`.
    To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder`
    argument and `add_cross_attention` set to `True`; an
    `encoder_hidden_states` is then expected as an input to the forward pass.
    """

    def __init__(self, config: ChatGLMConfig, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        # recording parameters
        self.max_sequence_length = config.max_sequence_length
        self.hidden_size = config.hidden_size
        self.params_dtype = torch.half
        self.num_attention_heads = config.num_attention_heads
        self.vocab_size = config.vocab_size
        self.num_layers = config.num_layers
        self.layernorm_epsilon = config.layernorm_epsilon
        self.inner_hidden_size = config.inner_hidden_size
        self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads
        self.position_encoding_2d = config.position_encoding_2d
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection

        self.word_embeddings = init_method(
            torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype
        )
        self.gradient_checkpointing = False

        def get_layer(layer_id):
            return GLMBlock(
                self.hidden_size,
                self.num_attention_heads,
                self.layernorm_epsilon,
                layer_id,
                inner_hidden_size=self.inner_hidden_size,
                hidden_size_per_attention_head=self.hidden_size_per_attention_head,
                layernorm=LayerNorm,
                use_bias=True,
                params_dtype=self.params_dtype,
                position_encoding_2d=self.position_encoding_2d,
                empty_init=empty_init,
            )

        self.layers = torch.nn.ModuleList([get_layer(layer_id) for layer_id in range(self.num_layers)])

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon)

        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

            # total_params = sum(p.numel() for p in self.parameters())
            # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
            # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params))

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.num_attention_heads,
            self.hidden_size // self.num_attention_heads,
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        # past_key_values = [(v[0], v[1]) for v in past_key_values]
        return past_key_values

    @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        if past_key_values is None:
            if self.pre_seq_len is not None:
                past_key_values = self.get_prompt(
                    batch_size=input_ids.shape[0], device=input_ids.device, dtype=inputs_embeds.dtype
                )
            else:
                past_key_values = tuple([None] * len(self.layers))

            if attention_mask is None:
                attention_mask = self.get_masks(input_ids, device=input_ids.device)

            if position_ids is None:
                MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
                seqs = input_ids.tolist()

                mask_positions, use_gmasks = [], []
                for seq in seqs:
                    mask_token = gMASK if gMASK in seq else MASK
                    use_gmask = mask_token == gMASK
                    mask_positions.append(seq.index(mask_token))
                    use_gmasks.append(use_gmask)

                position_ids = self.get_position_ids(
                    input_ids, mask_positions=mask_positions, device=input_ids.device, use_gmasks=use_gmasks
                )

        if self.pre_seq_len is not None and attention_mask is not None:
            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(
                attention_mask.device
            )
            prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)

        # [seq_len, batch, hidden_size]
        hidden_states = inputs_embeds.transpose(0, 1)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        if attention_mask is None:
            attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            layer_past = past_key_values[i]

            if self.gradient_checkpointing and self.training:
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    position_ids,
                    attention_mask,
                    torch.tensor(i),
                    layer_past,
                    use_cache,
                    output_attentions,
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    layer_id=torch.tensor(i),
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_ret[0]

            if use_cache:
                presents = presents + (layer_ret[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_ret[2 if use_cache else 1],)

        # Final layer norm.
        hidden_states = self.final_layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init

        # self.hidden_size = config.hidden_size
        # self.params_dtype = torch.half
        # self.vocab_size = config.vocab_size
        self.max_sequence_length = config.max_sequence_length

        self.position_encoding_2d = config.position_encoding_2d

        self.transformer = ChatGLMModel(config, empty_init=empty_init)

        self.lm_head = init_method(nn.Linear, config.hidden_size, config.vocab_size, bias=False, dtype=torch.half)

        self.config = config

        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                attention_mask = torch.cat(
                    [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3
                )
                new_attention_mask = attention_mask[:, :, -1:].clone()
                new_attention_mask[..., -1] = False
                model_kwargs["attention_mask"] = torch.cat([attention_mask, new_attention_mask], dim=2)

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone()
            new_position_id[:, 1, :] += 1
            model_kwargs["position_ids"] = torch.cat([position_ids, new_position_id], dim=-1)

        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        batch_size, seq_length = input_ids.shape
        MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
        seqs = input_ids.tolist()
        mask_positions, use_gmasks = [], []
        for seq in seqs:
            mask_token = gMASK if gMASK in seq else MASK
            use_gmask = mask_token == gMASK
            mask_positions.append(seq.index(mask_token))
            use_gmasks.append(use_gmask)

        # only last token for input_ids if past is not None
        if past is not None or past_key_values is not None:
            last_token = input_ids[:, -1].unsqueeze(-1)
            if attention_mask is not None and attention_mask.dtype == torch.bool:
                attention_mask = attention_mask[:, :, -1:]
            else:
                attention_mask = None
            if position_ids is not None:
                position_ids = position_ids[..., -1:]
            else:
                context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
                if self.position_encoding_2d:
                    position_ids = torch.tensor(
                        [
                            [mask_position, seq_length - context_length]
                            for mask_position, context_length in zip(mask_positions, context_lengths)
                        ],
                        dtype=torch.long,
                        device=input_ids.device,
                    ).unsqueeze(-1)
                else:
                    position_ids = torch.tensor(
                        [mask_position for mask_position in mask_positions], dtype=torch.long, device=input_ids.device
                    ).unsqueeze(-1)

            if past is None:
                past = past_key_values
            return {
                "input_ids": last_token,
                "past_key_values": past,
                "position_ids": position_ids,
                "attention_mask": attention_mask,
            }
        else:
            if attention_mask is not None and attention_mask.dtype != torch.bool:
                logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool")
                attention_mask = None
            if attention_mask is None:
                attention_mask = self.get_masks(input_ids, device=input_ids.device)
            if position_ids is None:
                position_ids = self.get_position_ids(
                    input_ids, device=input_ids.device, mask_positions=mask_positions, use_gmasks=use_gmasks
                )

            return {
                "input_ids": input_ids,
                "past_key_values": past,
                "position_ids": position_ids,
                "attention_mask": attention_mask,
            }

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states).permute(1, 0, 2).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )

    def process_response(self, response):
        response = response.strip()
        response = response.replace("[[训练时间]]", "2023年")
        punkts = [
            [",", ","],
            ["!", "!"],
            [":", ":"],
            [";", ";"],
            ["\?", "?"],
        ]
        for item in punkts:
            response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
            response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
        return response

    @torch.no_grad()
    def chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = None,
        max_length: int = 2048,
        num_beams=1,
        do_sample=True,
        top_p=0.7,
        temperature=0.95,
        logits_processor=None,
        **kwargs,
    ):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        outputs = self.generate(**inputs, **gen_kwargs)
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
        response = tokenizer.decode(outputs)
        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history

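    # --- Illustrative usage, not part of the original THUDM checkpoint code --
    # A typical invocation of `chat`, assuming the matching ChatGLM-6B
    # tokenizer and weights are available locally or on the Hugging Face Hub:
    #
    #     from transformers import AutoModel, AutoTokenizer
    #
    #     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    #     model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
    #     response, history = model.chat(tokenizer, "你好", history=[])
    #
    # `history` accumulates (query, response) pairs and is folded back into the
    # "[Round i]\n问:...\n答:..." prompt template used above on the next call.
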
    @torch.no_grad()
    def stream_chat(
        self,
        tokenizer,
        query: str,
        history: List[Tuple[str, str]] = None,
        max_length: int = 2048,
        do_sample=True,
        top_p=0.7,
        temperature=0.95,
        logits_processor=None,
        **kwargs,
    ):
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }
        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问：{}\n答：{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问：{}\n答：".format(len(history), query)
        inputs = tokenizer([prompt], return_tensors="pt")
        inputs = inputs.to(self.device)
        for outputs in self.stream_generate(**inputs, **gen_kwargs):
            outputs = outputs.tolist()[0][len(inputs["input_ids"][0]) :]
            response = tokenizer.decode(outputs)
            response = self.process_response(response)
            new_history = history + [(query, response)]
            yield response, new_history

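    # Usage sketch added for readability (not part of the upstream file): `stream_chat` is a
    # generator that yields the *entire* response decoded so far (not a delta), so a console
    # client only has to re-render the latest string:
    #
    #     for response, history in model.stream_chat(tokenizer, "你好", history=[]):
    #         print(response, end="\r")   # each yield supersedes the previous one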
    @torch.no_grad()
    def stream_generate(
        self,
        input_ids,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warn(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
                    UserWarning,
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break
            yield input_ids

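    # Note added for readability (not part of the upstream file): `stream_generate` yields the
    # growing `input_ids` tensor (prompt plus all tokens sampled so far) after every decoding
    # step, which is why `stream_chat` above re-decodes everything beyond the prompt length on
    # each iteration. Sampling behaviour comes from `generation_config`: extra keyword arguments
    # such as `top_p` and `temperature` are merged via `generation_config.update(**kwargs)` and
    # applied by the warpers returned from `self._get_logits_warper(generation_config)`.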
    def quantize(self, bits: int, empty_init=False, **kwargs):
        if bits == 0:
            return

        from .quantization import quantize

        if self.quantized:
            logger.info("Already quantized.")
            return self

        self.quantized = True

        self.config.quantization_bit = bits

        self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs)
        return self

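# Usage sketch added for readability (not part of the upstream file): `quantize` swaps the
# transformer's weights for INT8 (bits=8) or INT4 (bits=4) versions in place and records the
# choice in `config.quantization_bit`, so it only needs to be called once before inference,
# for example:
#
#     model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half()
#     model = model.quantize(4).cuda().eval()
#
# Calling it a second time only logs "Already quantized.", and `quantize(0)` is treated as
# "no quantization" and returns None rather than `self`, so it should not be used in a
# fluent chain.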