# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import glob
import math
import os
import struct
from typing import Dict, Optional

import numpy as np
import paddle
import paddle.distributed as dist
import paddle.incubate.multiprocessing as mp
from paddle.distributed import fleet
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
from sklearn.metrics import accuracy_score

from paddlenlp.datasets import InTokensIterableDataset
from paddlenlp.trainer import Trainer, TrainerCallback
from paddlenlp.trainer.trainer_utils import IterableDatasetShard, has_length
from paddlenlp.transformers import (
    AutoTokenizer,
    ChatGLMv2Tokenizer,
    LlamaForCausalLMPipe,
    PretrainedConfig,
)
from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
from paddlenlp.utils.log import logger


def compute_metrics(eval_preds):
    """Compute token-level accuracy, ignoring positions whose label is -100."""
    flattened_preds = np.array(eval_preds.predictions).flatten()
    flattened_labels = np.array(eval_preds.label_ids).flatten()
    filtered_preds = flattened_preds[flattened_labels != -100]
    filtered_labels = flattened_labels[flattened_labels != -100]
    accuracy = accuracy_score(y_true=filtered_labels, y_pred=filtered_preds)
    return {
        "accuracy": accuracy,
    }
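# Illustrative example for compute_metrics (hypothetical values): with argmax-ed
# predictions [[2, 5, 7]] and labels [[-100, 5, 9]], the -100 position is dropped,
# the remaining pairs are (5, 5) and (7, 9), and the returned accuracy is 0.5.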


def get_prefix_tuning_params(model):
    if model.base_model_prefix == "chatglm":
        from paddlenlp.peft.prefix import chatglm_postprocess_past_key_value

        num_attention_heads = model.config.num_attention_heads
        num_hidden_layers = model.config.num_hidden_layers
        hidden_size = model.config.hidden_size
        postprocess_past_key_value = chatglm_postprocess_past_key_value
        multi_query_group_num = None
    elif model.base_model_prefix == "chatglm_v2":
        from paddlenlp.peft.prefix import chatglm_postprocess_past_key_value

        num_attention_heads = model.config.num_attention_heads
        num_hidden_layers = model.config.num_layers
        hidden_size = model.config.hidden_size
        postprocess_past_key_value = chatglm_postprocess_past_key_value
        multi_query_group_num = model.config.multi_query_group_num
    elif model.base_model_prefix == "bloom":
        from paddlenlp.peft.prefix import bloom_postprocess_past_key_value

        num_attention_heads = model.config.num_attention_heads
        num_hidden_layers = model.config.n_layer
        hidden_size = model.config.n_embed
        postprocess_past_key_value = bloom_postprocess_past_key_value
        multi_query_group_num = None
    elif model.base_model_prefix == "llama":
        from paddlenlp.peft.prefix import llama_postprocess_past_key_value

        num_attention_heads = model.config.n_head
        num_hidden_layers = model.config.n_layer
        hidden_size = model.config.hidden_size
        postprocess_past_key_value = llama_postprocess_past_key_value
        multi_query_group_num = None
    elif model.base_model_prefix == "qwen":
        from paddlenlp.peft.prefix import qwen_postprocess_past_key_value

        num_attention_heads = model.config.num_attention_heads
        num_hidden_layers = model.config.num_hidden_layers
        hidden_size = model.config.hidden_size
        postprocess_past_key_value = qwen_postprocess_past_key_value
        multi_query_group_num = None
    else:
        raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
    return dict(
        num_attention_heads=num_attention_heads,
        num_hidden_layers=num_hidden_layers,
        hidden_size=hidden_size,
        postprocess_past_key_value=postprocess_past_key_value,
        multi_query_group_num=multi_query_group_num,
    )
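# The dict returned above is meant to be unpacked into prefix-tuning setup code.
# A minimal sketch (field names follow paddlenlp.peft's PrefixConfig /
# PrefixModelForCausalLM and may differ between PaddleNLP versions; the
# num_prefix_tokens value is hypothetical):
#
#     prefix_params = get_prefix_tuning_params(model)
#     prefix_config = PrefixConfig(
#         num_prefix_tokens=128,
#         num_attention_heads=prefix_params["num_attention_heads"],
#         num_hidden_layers=prefix_params["num_hidden_layers"],
#         hidden_size=prefix_params["hidden_size"],
#         multi_query_group_num=prefix_params["multi_query_group_num"],
#     )
#     model = PrefixModelForCausalLM(
#         model=model,
#         prefix_config=prefix_config,
#         postprocess_past_key_value=prefix_params["postprocess_past_key_value"],
#     )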


def get_lora_target_modules(model):
    # Not yet support RowParallelLinear
    if model.base_model_prefix == "chatglm":
        target_modules = [".*query_key_value.*", ".*dense.*", ".*dense_h_to_4h.*", ".*dense_4h_to_h.*"]
    elif model.base_model_prefix == "chatglm_v2":
        target_modules = [
            ".*query.*",
            ".*key.*",
            ".*value.*",
            ".*dense.*",
            ".*dense_h_to_4h.*",
            ".*dense_4h_to_h.*",
        ]
    elif model.base_model_prefix == "bloom":
        target_modules = [".*query_key_value.*", ".*dense.*", ".*dense_h_to_4h.*", ".*dense_4h_to_h.*"]
    elif model.base_model_prefix == "llama" or isinstance(model, LlamaForCausalLMPipe):
        target_modules = [
            ".*q_proj.*",
            ".*v_proj.*",
            ".*k_proj.*",
            ".*o_proj.*",
            ".*gate_proj.*",
            ".*down_proj.*",
            ".*up_proj.*",
        ]
    elif model.base_model_prefix == "opt":
        target_modules = [
            ".*project_in.*",
            ".*project_out.*",
            ".*q_proj.*",
            ".*k_proj.*",
            ".*v_proj.*",
            ".*qkv_proj.*",
            ".*out_proj.*",
            ".*linear1.*",
            ".*linear2.*",
        ]
    elif model.base_model_prefix == "qwen":
        target_modules = [
            ".*attn.c_attn.*",
            ".*attn.c_proj.*",
            ".*mlp.w1.*",
            ".*mlp.w2.*",
            ".*mlp.c_proj.*",
        ]
    elif model.base_model_prefix == "mixtral":
        target_modules = [
            ".*q_proj.*",
            ".*k_proj.*",
            ".*v_proj.*",
            ".*o_proj.*",
            ".*w1.*",
            ".*w2.*",
            ".*w3.*",
        ]
    else:
        raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
    return target_modules
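# A minimal sketch of how the regex list is typically consumed (LoRAConfig and
# LoRAModel come from paddlenlp.peft; the r and lora_alpha values are arbitrary
# examples, not recommendations, and field names may vary across versions):
#
#     target_modules = get_lora_target_modules(model)
#     lora_config = LoRAConfig(target_modules=target_modules, r=8, lora_alpha=16)
#     model = LoRAModel(model, lora_config)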


class InTokensIterDatasetCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that records the InTokens dataset's global step in
    `trial_params` so it can be restored when training resumes.
    """

    def on_step_end(self, args, state, control, **kwargs):
        train_dataloader = kwargs["train_dataloader"]
        if isinstance(train_dataloader.dataset, InTokensIterableDataset):
            dataset = train_dataloader.dataset
        elif isinstance(train_dataloader.dataset, IterableDatasetShard) and isinstance(
            train_dataloader.dataset.dataset, InTokensIterableDataset
        ):
            dataset = train_dataloader.dataset.dataset
        else:
            raise ValueError(
                "Unexpected dataset format: InTokensIterDatasetCallback expects `paddlenlp.datasets.InTokensIterableDataset`"
            )
        if state.trial_params is None:
            state.trial_params = {}
        state.trial_params["intokens_global_step"] = dataset.intokens_global_step


class CausalLMTrainer(Trainer):
    def __init__(self, do_generation: bool, gen_args, data_args, **kwargs):
        super().__init__(**kwargs)
        self.do_generation = do_generation
        self.gen_args = gen_args
        self.data_args = data_args

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only: bool,
        ignore_keys=None,
    ):
        if prediction_loss_only or self.args.pipeline_parallel_degree > 1:
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
        elif not self.do_generation:
            loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
            # argmax here to avoid gathering all logits, which is too memory-consuming.
            # keepdim in order to maintain the same shape as logits
            if isinstance(logits, (list, tuple)):
                logits = logits[0]
            return (loss, logits.argmax(axis=-1, keepdim=True), labels)

        loss = None

        model.eval()
        with paddle.no_grad():
            generated_tokens = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"] if "attention_mask" in inputs else None,
                position_ids=inputs["position_ids"] if "position_ids" in inputs else None,
                max_length=max(self.data_args.max_length - inputs["input_ids"].shape[-1], 1),
                decode_strategy="sampling",
                top_k=self.gen_args.top_k,
                top_p=self.gen_args.top_p,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                use_cache=True,
            )[0]
            all_preds = []
            for pred_tokens in generated_tokens:
                pred_tokens = pred_tokens[pred_tokens != self.tokenizer.pad_token_id].tolist()
                all_preds.append(pred_tokens)
            max_pred_length = max([len(x) for x in all_preds])
            for index, preds in enumerate(all_preds):
                all_preds[index] = preds + [-100] * (max_pred_length - len(preds))
            all_preds = paddle.to_tensor(all_preds)

            if "labels" in inputs:
                all_labels = paddle.to_tensor(inputs["labels"])
            else:
                all_labels = None

        return (loss, all_preds, all_labels)

    def log(self, logs: Dict[str, float], **kwargs) -> None:
        if "loss" in logs:
            logs["ppl"] = np.exp(logs["loss"])
        if "eval_loss" in logs:
            logs["eval_ppl"] = np.exp(logs["eval_loss"])

        super(CausalLMTrainer, self).log(logs, **kwargs)

    def get_ptq_dataloader(self, ptq_ds):
        if self.args.world_size <= 1:
            ptq_sampler = BatchSampler(
                dataset=ptq_ds,
                shuffle=True,
                batch_size=self.args.per_device_train_batch_size,
                drop_last=self.args.dataloader_drop_last,
            )
        else:
            ptq_sampler = DistributedBatchSampler(
                ptq_ds,  # sample from the PTQ dataset itself, matching the single-card branch
                batch_size=self.args.per_device_train_batch_size,
                shuffle=True,
                num_replicas=self.args.dataset_world_size,
                rank=self.args.dataset_rank,
                drop_last=self.args.dataloader_drop_last,
            )
        ptq_dataloader = DataLoader(
            ptq_ds,
            batch_sampler=ptq_sampler,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
        )
        return ptq_dataloader

    def ptq_loop(
        self,
        dataloader: DataLoader,
        description: str,
        max_eval_iters: Optional[int] = -1,
    ):
        if isinstance(dataloader, paddle.io.DataLoader):
            batch_size = dataloader.batch_sampler.batch_size
        else:
            raise ValueError("Only paddle.io.DataLoader is supported")

        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
            if max_eval_iters > 0:
                logger.info(f"  Total {description} steps = {max_eval_iters}")
            else:
                logger.info(f"  Total {description} steps = {len(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
            if max_eval_iters > 0:
                logger.info(f"  Total {description} steps = {max_eval_iters}")

        logger.info(f"  Per device batch size = {batch_size}")
        logger.info(f"  Total batch size = {batch_size * self.args.dataset_world_size}")
        self.model.eval()
        with paddle.no_grad():
            for step, inputs in enumerate(dataloader):
                self.prediction_step(model=self.model, inputs=inputs, prediction_loss_only=True, ignore_keys=None)
                if max_eval_iters > 0 and step >= max_eval_iters - 1:
                    break


def get_infer_model_path(input_dir, model_prefix):
    if dist.get_world_size() > 1:
        local_rank = dist.get_rank()
        return os.path.join(input_dir, "rank_{}".format(local_rank), model_prefix)
    else:
        return os.path.join(input_dir, model_prefix)
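# Example for get_infer_model_path: with world_size 2 and rank 1,
# get_infer_model_path("output", "llama") -> "output/rank_1/llama";
# on a single card it is simply "output/llama".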


def generate_rank_mapping(output_filename):
    ring_id = -1
    try:
        hcg = fleet.get_hybrid_communicate_group()
        model_parallel_group = hcg.get_model_parallel_group()
        ring_id = model_parallel_group.id
    except Exception:
        pass

    if ring_id == -1:
        return

    world_size = dist.get_world_size()
    with open(output_filename, "w") as f:
        f.write("[ring_id -> ranks]\n")
        f.write(",".join(map(str, [0] + list(range(world_size)))) + "\n")
        f.write(",".join(map(str, [ring_id] + list(range(world_size)))) + "\n")

        f.write("[rank -> ring_ids]\n")
        for i in range(world_size):
            f.write("{},0,{}\n".format(i, ring_id))


def deserialize_from_file(fp):
    x_type = fp.read(1)
    x_type_out = struct.unpack("c", x_type)[0]
    # data
    data_list = []
    if x_type_out == b"0":
        data = fp.read(4)
        while data:
            data_out = struct.unpack("f", data)[0]
            data_list.append(data_out)
            data = fp.read(4)
    elif x_type_out == b"1":
        data = fp.read(8)
        while data:
            data_out = struct.unpack("l", data)[0]
            data_list.append(data_out)
            data = fp.read(8)
    elif x_type_out == b"2":
        data = fp.read(4)
        while data:
            data_out = struct.unpack("i", data)[0]
            data_list.append(data_out)
            data = fp.read(4)
    else:
        logger.error(f"deserialize_from_file: unknown type flag {x_type_out!r}")
    data_arr = np.array(data_list)
    return data_arr
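# On-disk layout read by deserialize_from_file: one leading type byte
# (b"0" = 4-byte floats, b"1" = 8-byte integers, b"2" = 4-byte integers)
# followed by a flat stream of fixed-width values of that type until EOF.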


def get_alibi_slopes(num_heads):
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    powers = np.arange(1, 1 + closest_power_of_2)
    slopes = np.power(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = np.arange(1, 1 + 2 * num_remaining_heads, 2)
        slopes = np.concatenate([slopes, np.power(extra_base, extra_powers)], axis=0)

    return slopes.astype("float32")
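# Worked example for get_alibi_slopes: with num_heads=8 the closest power of two
# is 8, so base = 2 ** -(2 ** -(3 - 3)) = 0.5 and the slopes are
# [1/2, 1/4, ..., 1/256]. With num_heads=12 the first 8 slopes are the same and
# the 4 extra heads use extra_base = 2 ** -0.5 with odd powers 1, 3, 5, 7,
# giving roughly [0.7071, 0.3536, 0.1768, 0.0884].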


def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
    """Pad sequences to the max sequence length in batch."""
    max_len = max(map(len, insts))
    if pad_style == "left":
        inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts])
    else:
        inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])

    if return_seq_len:
        seq_len = np.array([len(inst) for inst in insts])
        return inst_data.astype("int64").reshape([-1, max_len]), seq_len
    else:
        return inst_data.astype("int64").reshape([-1, max_len])
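# Example for pad_batch_data: pad_batch_data([[1, 2, 3], [4, 5]], pad_id=0,
# return_seq_len=True) returns [[1, 2, 3], [4, 5, 0]] and seq_len [3, 2];
# with pad_style="left" the second row becomes [0, 4, 5].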


def dybatch_preprocess(
    tokenizer,
    texts: list[str],
    src_length: int,
    max_length: int,
    architectures: str,
    top_p: float,
    temperature: float,
    eos_token_id: int | list[list[int]],
    pre_caches_length: int = 0,
    benchmark: bool = False,
):
    """Pre-process generation inputs."""
    inputs = {}
    if "chatglmforcausallm" == architectures.lower():
        input_ids = []
        position_ids = []

        for text in texts:
            tokens = tokenizer(
                text,
                return_tensors="np",
                padding=True,
                max_length=src_length,
                # when a chat_template is used, the tokenizer does not add special tokens itself
                add_special_tokens=tokenizer.chat_template is None or isinstance(tokenizer, ChatGLMv2Tokenizer),
            )
            input_ids.append(tokens["input_ids"][0])
            position_ids.append(tokens["position_ids"][0])

        pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][0]
        inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
        bs = inputs["input_ids"].shape[0]
        max_len = max(map(len, input_ids))

        inst_data_pos = []
        for i in range(len(position_ids)):
            inst_data_pos.append(np.array([list(inst) + [0] * (max_len - len(inst)) for inst in position_ids[i]]))
        inputs["position_ids"] = paddle.to_tensor(np.array(inst_data_pos))
    elif "gpt" in architectures:
        input_ids = []
        if isinstance(texts, str):
            texts = [texts]

        for text in texts:
            tokens = tokenizer(
                text,
                return_tensors="np",
                padding=False,
                max_length=src_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            input_ids.append(tokens["input_ids"][0])

        pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1]
        inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
        bs = inputs["input_ids"].shape[0]
        max_len = max(map(len, input_ids))

        position_ids = paddle.arange(sum(seq_len), dtype="int64")
        pre_len = seq_len[0]
        for length in seq_len[1:]:
            position_ids[pre_len : length + pre_len] = position_ids[pre_len : length + pre_len] - pre_len
            pre_len += length
        inputs["position_ids"] = position_ids
    else:
        input_ids = []
        if isinstance(texts, str):
            texts = [texts]

        for text in texts:
            tokens = tokenizer(
                text,
                return_tensors="np",
                padding=False,
                max_length=src_length,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=tokenizer.chat_template is None or isinstance(tokenizer, ChatGLMv2Tokenizer),
            )
            input_ids.append(tokens["input_ids"][0])

        pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1]
        inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
        bs = inputs["input_ids"].shape[0]
        max_len = max(map(len, input_ids))

        position_ids = paddle.zeros(shape=[bs, max_length + src_length], dtype="int64")

        for i in range(bs):
            position_ids[i, pre_caches_length : pre_caches_length + seq_len[i]] = paddle.arange(seq_len[i])
        inputs["position_ids"] = position_ids

    tgt_ids = [input[-1:] for input in input_ids]
    tgt_pos = []
    for i, valid_len in enumerate(map(len, input_ids)):
        tgt_pos.append(valid_len - 1)
    tgt_pos = np.array(tgt_pos).astype("int64")

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    inputs["eos_token_id"] = np.array(eos_token_id * bs).reshape(-1, 1).astype("int64")

    # Per-sample generation hyper-parameters are broadcast into [bs, 1] columns.
    inputs["top_p"] = np.full([bs, 1], top_p, dtype="float32")
    inputs["temperature"] = np.full([bs, 1], temperature, dtype="float32")
    inputs["seq_len_encoder"] = seq_len.astype("int32").reshape(-1, 1)
    inputs["seq_len_decoder"] = (seq_len + pre_caches_length).astype("int32").reshape(-1, 1)
    inputs["step_idx"] = np.full([bs, 1], 0, dtype="int64")
    inputs["tgt_ids"] = np.array(tgt_ids).astype("int64").reshape(-1, 1)
    inputs["tgt_pos"] = tgt_pos.reshape(-1, 1)
    inputs["max_length"] = np.array(max_length - pre_caches_length).astype("int64").reshape((-1, 1))
    # Note(Zhengzekang): in benchmark mode a fixed decode length is enforced via min_length.
    min_dec_length = 1 if not benchmark else max_length - pre_caches_length
    inputs["min_length"] = np.full([bs, 1], min_dec_length, dtype="int64")
    inputs["penalty_score"] = np.full([bs, 1], 1.0, dtype="float32")
    inputs["frequency_score"] = np.full([bs, 1], 0.0, dtype="float32")
    inputs["presence_score"] = np.full([bs, 1], 0.0, dtype="float32")
    inputs["stop_flags"] = np.full([bs, 1], False, dtype="bool")
    inputs["stop_nums"] = np.array([bs]).astype("int64")
    return inputs
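# Illustrative call (argument values are hypothetical; architectures normally
# comes from the model config, e.g. config.architectures[0]):
#
#     inputs = dybatch_preprocess(
#         tokenizer, ["Hello"], src_length=1024, max_length=1024,
#         architectures="LlamaForCausalLM", top_p=0.7, temperature=0.95,
#         eos_token_id=tokenizer.eos_token_id,
#     )
#     # inputs then holds input_ids/position_ids plus per-sample sampling columns
#     # (top_p, temperature, penalty_score, ...) shaped [batch_size, 1].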


def load_real_time_tokens():
    tokens = []
    files = glob.glob("./real_time_save.*")
    for j in range(1, len(files) + 1):
        filename = "./real_time_save.temp_ids_rank_0_step_{}".format(j)
        if not os.path.exists(filename):
            break
        fp = open(filename, "rb+")
        fp.read(1)
        data_list = deserialize_from_file(fp)
        fp.close()
        tokens.append(np.array(data_list).reshape(-1, 1))
    os.system("rm -f ./real_time_save.temp_ids_rank_*")
    tokens = np.concatenate(tokens, axis=1)
    return tokens


def init_chat_template(
    tokenizer: PretrainedTokenizer, model_name_or_path: str, chat_template_file: Optional[str] = None
):
    """Initialize the chat template for the given tokenizer.

        If `chat_template_file` is None, `chat_template.json` is not used;
        if it equals `model_name_or_path`, the default loading is kept;
        if it is a directory, `chat_template.json` is looked up under that directory;
        if it is a file, it is loaded directly.

    Args:
        tokenizer (PretrainedTokenizer): the tokenizer instance
        model_name_or_path (str): the model name or local path the tokenizer was loaded from
        chat_template_file (Optional[str], optional): path or directory of the chat template. Defaults to None.
    """
    # 1. use the default chat_template file
    if chat_template_file is None:
        return

    if str(chat_template_file).lower() == "none":
        # delete the chat_template from the tokenizer when chat_template is not used.
        # why do this: the `chat_template.json` file is loaded by default otherwise
        tokenizer.chat_template = None
        return

    # it will load the `chat_template.json` file by default, so do nothing
    if chat_template_file == model_name_or_path:
        if tokenizer.chat_template is None:
            logger.warning(f"there is no `chat_template.json` file in `{model_name_or_path}`")
        return

    if os.path.isdir(chat_template_file):
        local_chat_template_file_path = os.path.join(chat_template_file, "chat_template.json")
        if os.path.exists(local_chat_template_file_path):
            chat_template_file = local_chat_template_file_path
        else:
            logger.warning(f"there is no `chat_template.json` file in `{chat_template_file}`")
            return

    if not os.path.exists(chat_template_file):
        logger.warning(f"there is no `chat_template.json` file at `{chat_template_file}`")
        return

    logger.info(f"loading `chat_template.json` from `{chat_template_file}`")
    tokenizer.init_chat_template(chat_template_file)


def get_model_max_position_embeddings(config: PretrainedConfig) -> Optional[int]:
    names = [
        "max_position_embeddings",  # most models
        "max_sequence_length",  # GLM model
        "seq_length",  # llama model
    ]
    for name in names:
        max_length = config.get(name, None)
        if max_length is not None:
            return max_length
    return None


def get_default_max_decoding_length(config: PretrainedConfig, default: int = 1024) -> int:
    """Get the default max decoding length from config.

    Args:
        config (PretrainedConfig): the instance of PretrainedConfig
        default (int): the default value of the max decoding length

    Returns:
        int: the default max decoding length
    """
    max_position_embeddings = get_model_max_position_embeddings(config)
    if max_position_embeddings is None:
        return default
    return max_position_embeddings // 4


def get_default_max_encoding_length(config: PretrainedConfig, default: int = 1024) -> int:
    """Get the default max encoding length from config.

    Args:
        config (PretrainedConfig): the instance of PretrainedConfig
        default (int): the default value of the max encoding length

    Returns:
        int: the default max encoding length
    """

    max_position_embeddings = get_model_max_position_embeddings(config)
    if max_position_embeddings is None:
        return default
    return max_position_embeddings // 4 * 3
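# Example for the two helpers above: a config with max_position_embeddings=2048
# yields a default encoding budget of 1536 tokens and a decoding budget of 512
# (a 3:1 split of the context window); without any known length field both fall
# back to `default` (1024).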


def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
    )

    paddle.device.set_device("cpu")
    outputs = []
    output_tensor = tensor_queue.get(timeout=1)

    logger.info("Start reading result messages")
    logger.info(f"Current path is {os.getcwd()}")

    from paddlenlp_ops import get_output

    while True:
        get_output(output_tensor, 0, True)
        if output_tensor[0, 0] == -2:  # read none
            continue
        bsz = output_tensor[1, 0].numpy()
        output_numpy = output_tensor[2 : bsz + 2].numpy()
        output_numpy[output_numpy == -1] = 2
        outputs.append(output_numpy)
        if output_tensor[0, 0] == -1:
            break
    output = np.concatenate(outputs, axis=1).tolist()
    seqs = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    for i, seq in enumerate(seqs):
        result_queue.put([i, seq])

    logger.info("Finished reading result messages")