from __future__ import annotations

import glob
import math
import os
import struct
from typing import Dict, Optional

import numpy as np
import paddle
import paddle.distributed as dist
import paddle.incubate.multiprocessing as mp
from paddle.distributed import fleet
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
from sklearn.metrics import accuracy_score

from paddlenlp.datasets import InTokensIterableDataset
from paddlenlp.trainer import Trainer, TrainerCallback
from paddlenlp.trainer.trainer_utils import IterableDatasetShard, has_length
from paddlenlp.transformers import (
    AutoTokenizer,
    ChatGLMv2Tokenizer,
    LlamaForCausalLMPipe,
    PretrainedConfig,
)
from paddlenlp.transformers.tokenizer_utils import PretrainedTokenizer
from paddlenlp.utils.log import logger


def compute_metrics(eval_preds):
    # flatten predictions/labels and drop ignored (-100) positions before scoring
    flattened_preds = np.array(eval_preds.predictions).flatten()
    flattened_labels = np.array(eval_preds.label_ids).flatten()
    filtered_preds = flattened_preds[flattened_labels != -100]
    filtered_labels = flattened_labels[flattened_labels != -100]
    accuracy = accuracy_score(y_true=filtered_labels, y_pred=filtered_preds)
    return {"accuracy": accuracy}
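
# Illustrative check (not part of the original module): `compute_metrics` takes an
# `EvalPrediction`-like object with `predictions` and `label_ids`; -100 marks
# padded positions that are excluded from accuracy.
#
#   from types import SimpleNamespace
#   preds = SimpleNamespace(predictions=[[1, 2, 3]], label_ids=[[1, -100, 3]])
#   compute_metrics(preds)  # -> {"accuracy": 1.0}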


def get_prefix_tuning_params(model):
    if model.base_model_prefix == "chatglm":
        from paddlenlp.peft.prefix import chatglm_postprocess_past_key_value

        num_attention_heads = model.config.num_attention_heads
        num_hidden_layers = model.config.num_hidden_layers
        hidden_size = model.config.hidden_size
        postprocess_past_key_value = chatglm_postprocess_past_key_value
        multi_query_group_num = None
    elif model.base_model_prefix == "chatglm_v2":
        from paddlenlp.peft.prefix import chatglm_postprocess_past_key_value

        num_attention_heads = model.config.num_attention_heads
        num_hidden_layers = model.config.num_layers
        hidden_size = model.config.hidden_size
        postprocess_past_key_value = chatglm_postprocess_past_key_value
        multi_query_group_num = model.config.multi_query_group_num
    elif model.base_model_prefix == "bloom":
        from paddlenlp.peft.prefix import bloom_postprocess_past_key_value

        num_attention_heads = model.config.num_attention_heads
        num_hidden_layers = model.config.n_layer
        hidden_size = model.config.n_embed
        postprocess_past_key_value = bloom_postprocess_past_key_value
        multi_query_group_num = None
    elif model.base_model_prefix == "llama":
        from paddlenlp.peft.prefix import llama_postprocess_past_key_value

        num_attention_heads = model.config.n_head
        num_hidden_layers = model.config.n_layer
        hidden_size = model.config.hidden_size
        postprocess_past_key_value = llama_postprocess_past_key_value
        multi_query_group_num = None
    elif model.base_model_prefix == "qwen":
        from paddlenlp.peft.prefix import qwen_postprocess_past_key_value

        num_attention_heads = model.config.num_attention_heads
        num_hidden_layers = model.config.num_hidden_layers
        hidden_size = model.config.hidden_size
        postprocess_past_key_value = qwen_postprocess_past_key_value
        multi_query_group_num = None
    else:
        raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
    return dict(
        num_attention_heads=num_attention_heads,
        num_hidden_layers=num_hidden_layers,
        hidden_size=hidden_size,
        postprocess_past_key_value=postprocess_past_key_value,
        multi_query_group_num=multi_query_group_num,
    )
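
# Sketch of how the returned dict is typically consumed (hedged: the PEFT API
# below may differ across PaddleNLP versions, and `num_prefix_tokens` is a
# user-chosen hyperparameter, not a value from this module):
#
#   from paddlenlp.peft import PrefixConfig, PrefixModelForCausalLM
#
#   params = get_prefix_tuning_params(model)
#   prefix_config = PrefixConfig(
#       num_prefix_tokens=64,
#       num_attention_heads=params["num_attention_heads"],
#       num_hidden_layers=params["num_hidden_layers"],
#       hidden_size=params["hidden_size"],
#       multi_query_group_num=params["multi_query_group_num"],
#   )
#   model = PrefixModelForCausalLM(
#       model=model,
#       prefix_config=prefix_config,
#       postprocess_past_key_value=params["postprocess_past_key_value"],
#   )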


def get_lora_target_modules(model):
    # regex patterns matching the linear layers to wrap with LoRA
    if model.base_model_prefix == "chatglm":
        target_modules = [".*query_key_value.*", ".*dense.*", ".*dense_h_to_4h.*", ".*dense_4h_to_h.*"]
    elif model.base_model_prefix == "chatglm_v2":
        # NOTE: the pattern lists for this and several branches below were elided
        # in this excerpt; they are reconstructed here as plausible attention/MLP
        # projection names for each architecture.
        target_modules = [".*query.*", ".*key.*", ".*value.*", ".*dense.*", ".*dense_h_to_4h.*", ".*dense_4h_to_h.*"]
    elif model.base_model_prefix == "bloom":
        target_modules = [".*query_key_value.*", ".*dense.*", ".*dense_h_to_4h.*", ".*dense_4h_to_h.*"]
    elif model.base_model_prefix == "llama" or isinstance(model, LlamaForCausalLMPipe):
        target_modules = [
            ".*q_proj.*",
            ".*k_proj.*",
            ".*v_proj.*",
            ".*o_proj.*",
            ".*gate_proj.*",
            ".*down_proj.*",
            ".*up_proj.*",
        ]
    elif model.base_model_prefix == "opt":
        target_modules = [
            ".*project_in.*",
            ".*project_out.*",
            ".*q_proj.*",
            ".*k_proj.*",
            ".*v_proj.*",
            ".*out_proj.*",
            ".*linear1.*",
            ".*linear2.*",
        ]
    elif model.base_model_prefix == "qwen":
        target_modules = [".*attn.c_attn.*", ".*attn.c_proj.*", ".*mlp.w1.*", ".*mlp.w2.*", ".*mlp.c_proj.*"]
    elif model.base_model_prefix == "mixtral":
        target_modules = [".*q_proj.*", ".*k_proj.*", ".*v_proj.*", ".*o_proj.*", ".*w1.*", ".*w2.*", ".*w3.*"]
    else:
        raise ValueError(f"Unknown base_model_prefix: {model.base_model_prefix}.")
    return target_modules
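
# Sketch of the intended use (hedged: the LoRA hyperparameters here are example
# values, not values from this module):
#
#   from paddlenlp.peft import LoRAConfig, LoRAModel
#
#   lora_config = LoRAConfig(
#       target_modules=get_lora_target_modules(model),
#       r=8,
#       lora_alpha=16,
#   )
#   model = LoRAModel(model, lora_config)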


class InTokensIterDatasetCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that records the consumed-example counter of an
    `InTokensIterableDataset` so training can resume from the right position.
    """

    def on_step_end(self, args, state, control, **kwargs):
        train_dataloader = kwargs["train_dataloader"]
        if isinstance(train_dataloader.dataset, InTokensIterableDataset):
            dataset = train_dataloader.dataset
        elif isinstance(train_dataloader.dataset, IterableDatasetShard) and isinstance(
            train_dataloader.dataset.dataset, InTokensIterableDataset
        ):
            dataset = train_dataloader.dataset.dataset
        else:
            raise ValueError(
                "Unexpected dataset format: InTokensIterDatasetCallback expects `paddlenlp.datasets.InTokensIterableDataset`"
            )
        if state.trial_params is None:
            state.trial_params = {}
        state.trial_params["intokens_global_step"] = dataset.intokens_global_step


class CausalLMTrainer(Trainer):
    def __init__(self, do_generation: bool, gen_args, data_args, **kwargs):
        super().__init__(**kwargs)
        self.do_generation = do_generation
        self.gen_args = gen_args
        self.data_args = data_args

    def prediction_step(
        self,
        model,
        inputs,
        prediction_loss_only: bool,
        ignore_keys=None,
    ):
        if prediction_loss_only or self.args.pipeline_parallel_degree > 1:
            return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
        elif not self.do_generation:
            loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
            # argmax here to avoid gathering the full logits, which is memory-consuming;
            # keepdim so the result keeps the same rank as the logits
            if isinstance(logits, (list, tuple)):
                logits = logits[0]
            return (loss, logits.argmax(axis=-1, keepdim=True), labels)

        loss = None

        model.eval()
        with paddle.no_grad():
            generated_tokens = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"] if "attention_mask" in inputs else None,
                position_ids=inputs["position_ids"] if "position_ids" in inputs else None,
                max_length=max(self.data_args.max_length - inputs["input_ids"].shape[-1], 1),
                decode_strategy="sampling",
                top_k=self.gen_args.top_k,
                top_p=self.gen_args.top_p,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )[0]  # generate() returns (ids, scores); keep the ids
            all_preds = []
            for pred_tokens in generated_tokens:
                pred_tokens = pred_tokens[pred_tokens != self.tokenizer.pad_token_id].tolist()
                all_preds.append(pred_tokens)
            # right-pad predictions with -100 so they stack into one tensor
            max_pred_length = max([len(x) for x in all_preds])
            for index, preds in enumerate(all_preds):
                all_preds[index] = preds + [-100] * (max_pred_length - len(preds))
            all_preds = paddle.to_tensor(all_preds)

            if "labels" in inputs:
                all_labels = paddle.to_tensor(inputs["labels"])
            else:
                all_labels = None

        return (loss, all_preds, all_labels)

    def log(self, logs: Dict[str, float], **kwargs) -> None:
        # report perplexity alongside the raw loss values
        if "loss" in logs:
            logs["ppl"] = np.exp(logs["loss"])
        if "eval_loss" in logs:
            logs["eval_ppl"] = np.exp(logs["eval_loss"])

        super(CausalLMTrainer, self).log(logs, **kwargs)
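
    # Worked example: a cross-entropy loss of 2.0 is logged with ppl = e**2.0 ≈ 7.39,
    # since perplexity is the exponential of the average cross-entropy loss.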

    def get_ptq_dataloader(self, ptq_ds):
        if self.args.world_size <= 1:
            ptq_sampler = BatchSampler(
                dataset=ptq_ds,
                shuffle=True,
                batch_size=self.args.per_device_train_batch_size,
                drop_last=self.args.dataloader_drop_last,
            )
        else:
            ptq_sampler = DistributedBatchSampler(
                ptq_ds,
                batch_size=self.args.per_device_train_batch_size,
                shuffle=True,
                num_replicas=self.args.dataset_world_size,
                rank=self.args.dataset_rank,
                drop_last=self.args.dataloader_drop_last,
            )
        ptq_dataloader = DataLoader(
            ptq_ds,
            batch_sampler=ptq_sampler,
            collate_fn=self.data_collator,
            num_workers=self.args.dataloader_num_workers,
        )
        return ptq_dataloader

    def ptq_loop(
        self,
        dataloader: DataLoader,
        description: str,
        max_eval_iters: Optional[int] = -1,
    ):
        # calibration loop for post-training quantization: forward passes only
        # (the signature lines were elided in this excerpt and are reconstructed)
        if isinstance(dataloader, paddle.io.DataLoader):
            batch_size = dataloader.batch_sampler.batch_size
        else:
            raise ValueError("Only support for paddle.io.DataLoader")

        if has_length(dataloader):
            logger.info(f" Num examples = {self.num_examples(dataloader)}")
            if max_eval_iters > 0:
                logger.info(f" Total {description} steps = {max_eval_iters}")
            else:
                logger.info(f" Total {description} steps = {len(dataloader)}")
        else:
            logger.info(" Num examples: Unknown")
            if max_eval_iters > 0:
                logger.info(f" Total {description} steps = {max_eval_iters}")
        logger.info(f" Per device batch size = {batch_size}")
        logger.info(f" Total batch size = {batch_size * self.args.dataset_world_size}")

        with paddle.no_grad():
            for step, inputs in enumerate(dataloader):
                self.prediction_step(model=self.model, inputs=inputs, prediction_loss_only=True, ignore_keys=None)
                if max_eval_iters > 0 and step >= max_eval_iters - 1:
                    break


def get_infer_model_path(input_dir, model_prefix):
    if dist.get_world_size() > 1:
        local_rank = dist.get_rank()
        return os.path.join(input_dir, "rank_{}".format(local_rank), model_prefix)
    else:
        return os.path.join(input_dir, model_prefix)


def generate_rank_mapping(output_filename):
    ring_id = -1
    try:
        hcg = fleet.get_hybrid_communicate_group()
        model_parallel_group = hcg.get_model_parallel_group()
        ring_id = model_parallel_group.id
    except Exception:
        # fleet is not initialized; there is no mapping to write
        pass

    if ring_id == -1:
        return

    world_size = dist.get_world_size()
    with open(output_filename, "w") as f:
        f.write("[ring_id -> ranks]\n")
        f.write(",".join(map(str, [0] + list(range(world_size)))) + "\n")
        f.write(",".join(map(str, [ring_id] + list(range(world_size)))) + "\n")

        f.write("[rank -> ring_ids]\n")
        for i in range(world_size):
            f.write("{},0,{}\n".format(i, ring_id))


def deserialize_from_file(fp):
    # the first byte is a type tag: b"0" -> float32, b"1" -> int64, b"2" -> int32
    x_type = fp.read(1)
    x_type_out = struct.unpack("c", x_type)[0]
    if x_type_out == b"0":
        fmt, width = "f", 4
    elif x_type_out == b"1":
        fmt, width = "l", 8
    elif x_type_out == b"2":
        fmt, width = "i", 4
    else:
        raise ValueError(f"unknown type tag: {x_type_out!r}")
    # read values until the -1 sentinel written by the producer
    data_list = []
    data_out = struct.unpack(fmt, fp.read(width))[0]
    while data_out != -1:
        data_list.append(data_out)
        data_out = struct.unpack(fmt, fp.read(width))[0]
    data_arr = np.array(data_list)
    return data_arr


def get_alibi_slopes(num_heads):
    # standard ALiBi slopes: a geometric sequence over the nearest power of two,
    # extended with interleaved extra slopes when num_heads is not a power of two
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    powers = np.arange(1, 1 + closest_power_of_2)
    slopes = np.power(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = np.arange(1, 1 + 2 * num_remaining_heads, 2)
        slopes = np.concatenate([slopes, np.power(extra_base, extra_powers)], axis=0)

    return slopes.astype("float32")
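
# Sanity check: for num_heads=8 the base is 2**-1, so
# get_alibi_slopes(8) -> [1/2, 1/4, 1/8, ..., 1/256]; for a non-power-of-two
# head count, extra slopes are drawn from the next power of two's sequence.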


def pad_batch_data(insts, pad_id=0, return_seq_len=False, pad_style="right"):
    """Pad sequences to the max sequence length in batch."""
    max_len = max(map(len, insts))
    if pad_style == "left":
        inst_data = np.array([[pad_id] * (max_len - len(inst)) + list(inst) for inst in insts])
    else:
        inst_data = np.array([list(inst) + [pad_id] * (max_len - len(inst)) for inst in insts])

    if return_seq_len:
        seq_len = np.array([len(inst) for inst in insts])
        return inst_data.astype("int64").reshape([-1, max_len]), seq_len
    else:
        return inst_data.astype("int64").reshape([-1, max_len])
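
# Example:
#   pad_batch_data([[1, 2, 3], [4, 5]], pad_id=0)
#   -> array([[1, 2, 3],
#             [4, 5, 0]])
#   pad_batch_data([[1, 2, 3], [4, 5]], pad_id=0, pad_style="left", return_seq_len=True)
#   -> (array([[1, 2, 3],
#              [0, 4, 5]]), array([3, 2]))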


def dybatch_preprocess(
    tokenizer,
    texts: list[str],
    src_length: int,
    max_length: int,
    architectures: str,
    top_p: float,
    temperature: float,
    eos_token_id: int | list[list[int]],
    pre_caches_length: int = 0,
    benchmark: bool = False,
):
    """Pre-process generation inputs.

    NOTE: the leading parameters were elided in this excerpt and are
    reconstructed from how they are used below.
    """
    inputs = {}
    if "chatglmforcausallm" == architectures.lower():
        input_ids = []
        position_ids = []

        for text in texts:
            tokens = tokenizer(
                text,
                return_tensors="np",
                padding=True,
                max_length=src_length,
                # when a chat_template is configured, the tokenizer adds no special tokens itself
                add_special_tokens=tokenizer.chat_template is None or isinstance(tokenizer, ChatGLMv2Tokenizer),
            )
            input_ids.append(tokens["input_ids"][0])
            position_ids.append(tokens["position_ids"][0])

        pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][0]
        inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
        bs = inputs["input_ids"].shape[0]
        max_len = max(map(len, input_ids))

        inst_data_pos = []
        for i in range(len(position_ids)):
            inst_data_pos.append(np.array([list(inst) + [0] * (max_len - len(inst)) for inst in position_ids[i]]))
        inputs["position_ids"] = paddle.to_tensor(np.array(inst_data_pos))
    elif "gpt" in architectures:
        input_ids = []
        if isinstance(texts, str):
            texts = [texts]

        for text in texts:
            tokens = tokenizer(
                text,
                return_tensors="np",
                padding=True,
                max_length=src_length,
                return_attention_mask=False,
                return_token_type_ids=False,
            )
            input_ids.append(tokens["input_ids"][0])

        pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1]
        inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
        bs = inputs["input_ids"].shape[0]
        max_len = max(map(len, input_ids))

        # restart the position counter at the start of every sequence in the flattened batch
        position_ids = paddle.arange(sum(seq_len), dtype="int64")
        pre_len = seq_len[0]
        for length in seq_len[1:]:
            position_ids[pre_len : length + pre_len] = position_ids[pre_len : length + pre_len] - pre_len
            pre_len += length
        inputs["position_ids"] = position_ids
    else:
        input_ids = []
        if isinstance(texts, str):
            texts = [texts]

        for text in texts:
            tokens = tokenizer(
                text,
                return_tensors="np",
                padding=True,
                max_length=src_length,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=tokenizer.chat_template is None or isinstance(tokenizer, ChatGLMv2Tokenizer),
            )
            input_ids.append(tokens["input_ids"][0])

        pad_token_id = tokenizer([tokenizer.pad_token], return_tensors="np")["input_ids"][0][-1]
        inputs["input_ids"], seq_len = pad_batch_data(input_ids, pad_id=pad_token_id, return_seq_len=True)
        bs = inputs["input_ids"].shape[0]
        max_len = max(map(len, input_ids))

        position_ids = paddle.zeros(shape=[bs, max_length + src_length], dtype="int64")

        for i in range(bs):
            position_ids[i, pre_caches_length : pre_caches_length + seq_len[i]] = paddle.arange(seq_len[i])
        inputs["position_ids"] = position_ids

    tgt_ids = [input[-1:] for input in input_ids]
    tgt_pos = []
    for i, valid_len in enumerate(map(len, input_ids)):
        tgt_pos.append(valid_len - 1)

    step_idx = [0] * bs
    tgt_pos = np.array(tgt_pos).astype("int64")

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]

    inputs["eos_token_id"] = np.array(eos_token_id * bs).reshape(-1, 1).astype("int64")
    # per-row sampling settings and decoding state; the scalar defaults below
    # (penalty 1.0, frequency/presence 0.0, stop_flags False) reconstruct lines
    # elided from this excerpt
    inputs["top_p"] = np.array([top_p] * bs).reshape(-1, 1).astype("float32")
    inputs["temperature"] = np.array([temperature] * bs).reshape(-1, 1).astype("float32")
    inputs["seq_len_encoder"] = seq_len.astype("int32").reshape(-1, 1)
    inputs["seq_len_decoder"] = (seq_len + pre_caches_length).astype("int32").reshape(-1, 1)
    inputs["step_idx"] = np.array(step_idx).astype("int64").reshape(-1, 1)
    inputs["tgt_ids"] = np.array(tgt_ids).astype("int64").reshape(-1, 1)
    inputs["tgt_pos"] = tgt_pos.reshape(-1, 1)
    inputs["max_length"] = np.array(max_length - pre_caches_length).astype("int64").reshape((-1, 1))
    # in benchmark mode, pin the decode length so every run generates the same number of tokens
    inputs["min_length"] = (
        np.array([1 if not benchmark else max_length - pre_caches_length] * bs).reshape(-1, 1).astype("int64")
    )
    inputs["penalty_score"] = np.array([1.0] * bs).reshape(-1, 1).astype("float32")
    inputs["frequency_score"] = np.array([0.0] * bs).reshape(-1, 1).astype("float32")
    inputs["presence_score"] = np.array([0.0] * bs).reshape(-1, 1).astype("float32")
    inputs["stop_flags"] = np.array([0] * bs).reshape(-1, 1).astype("bool")
    inputs["stop_nums"] = np.array([bs]).astype("int64")
    return inputs


def load_real_time_tokens():
    tokens = []
    files = glob.glob(os.path.join("./real_time_save.*"))
    for j in range(1, len(files) + 1):
        filename = "./real_time_save.temp_ids_rank_0_step_{}".format(j)
        if not os.path.exists(filename):
            break
        fp = open(filename, "rb+")
        data_list = deserialize_from_file(fp)
        fp.close()
        tokens.append(np.array(data_list).reshape(-1, 1))
    os.system("rm -f ./real_time_save.temp_ids_rank_*")
    tokens = np.concatenate(tokens, axis=1)
    return tokens


def init_chat_template(
    tokenizer: PretrainedTokenizer, model_name_or_path: str, chat_template_file: Optional[str] = None
):
    """Initialize the chat template for the given tokenizer.

    If `chat_template_file` is None, `chat_template.json` will not be used;
    if it equals `model_name_or_path`, the default loading applies;
    if it is a directory, the `chat_template.json` under that directory is used;
    if it is a file, that file is loaded.

    Args:
        tokenizer (PretrainedTokenizer): the instance of tokenizer
        model_name_or_path (str): the name or local path of the pretrained model
        chat_template_file (Optional[str], optional): the chat template file or its directory. Defaults to None.
    """
    if chat_template_file is None:
        return

    if str(chat_template_file).lower() == "none":
        # explicitly disable the chat template; otherwise `chat_template.json`
        # would be loaded by default
        tokenizer.chat_template = None
        return

    # the default loading already picks up `chat_template.json`, so do nothing
    if chat_template_file == model_name_or_path:
        if tokenizer.chat_template is None:
            logger.warning(f"there is no `chat_template.json` file in `{model_name_or_path}`")
        return

    if os.path.isdir(chat_template_file):
        local_chat_template_file_path = os.path.join(chat_template_file, "chat_template.json")
        if os.path.exists(local_chat_template_file_path):
            chat_template_file = local_chat_template_file_path
        else:
            logger.warning(f"there is no `chat_template.json` file in `{chat_template_file}`")
            return

    if not os.path.exists(chat_template_file):
        logger.warning(f"there is no `chat_template.json` file at `{chat_template_file}`")
        return

    logger.info(f"loading `chat_template.json` from `{chat_template_file}`")
    tokenizer.init_chat_template(chat_template_file)
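
# Typical usage (illustrative; the model name and directory are placeholders):
#
#   tokenizer = AutoTokenizer.from_pretrained("my-model")
#   init_chat_template(tokenizer, "my-model", "./templates/")
#   # loads ./templates/chat_template.json if it exists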


def get_model_max_position_embeddings(config: PretrainedConfig) -> Optional[int]:
    names = [
        "max_position_embeddings",  # most models
        "max_sequence_length",  # GLM
        "seq_length",  # some llama-style configs
    ]
    for name in names:
        max_length = config.get(name, None)
        if max_length is not None:
            return max_length
    return None


def get_default_max_decoding_length(config: PretrainedConfig, default: int = 1024) -> int:
    """Get the default max decoding length from the config.

    Args:
        config (PretrainedConfig): the instance of PretrainedConfig
        default (int): the default value of max decoding length

    Returns:
        int: the default max decoding length
    """
    max_position_embeddings = get_model_max_position_embeddings(config)
    if max_position_embeddings is None:
        return default
    return max_position_embeddings // 4


def get_default_max_encoding_length(config: PretrainedConfig, default: int = 1024) -> int:
    """Get the default max encoding length from the config.

    Args:
        config (PretrainedConfig): the instance of PretrainedConfig
        default (int): the default value of max encoding length

    Returns:
        int: the default max encoding length
    """
    max_position_embeddings = get_model_max_position_embeddings(config)
    if max_position_embeddings is None:
        return default
    return max_position_embeddings // 4 * 3
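
# Worked example: with max_position_embeddings=4096 the context window is split
# 3:1 between prompt and generation:
#   get_default_max_encoding_length(config)  # -> 3072
#   get_default_max_decoding_length(config)  # -> 1024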


def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    paddle.device.set_device("cpu")
    outputs = []
    output_tensor = tensor_queue.get(timeout=1)

    logger.info("Start read result message")
    logger.info(f"Current path is {os.getcwd()}")

    from paddlenlp_ops import get_output

    while True:
        get_output(output_tensor, 0, True)
        if output_tensor[0, 0] == -2:  # no new tokens yet
            continue
        bsz = output_tensor[1, 0].numpy()
        output_numpy = output_tensor[2 : bsz + 2].numpy()
        # -1 marks empty slots; map them to a valid token id before decoding
        output_numpy[output_numpy == -1] = 2
        outputs.append(output_numpy)
        if output_tensor[0, 0] == -1:  # end-of-generation sentinel
            break
    output = np.concatenate(outputs, axis=1).tolist()
    seqs = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    for i, seq in enumerate(seqs):
        result_queue.put([i, seq])

    logger.info("Finish read result message")