# coding=utf-8
# Calculates the distribution of the input lengths in the dataset.
# Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default

from collections import defaultdict
from typing import Optional

import fire
from tqdm import tqdm

from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args
from llmtuner.model import load_model_and_tokenizer


def length_cdf(
    model_name_or_path: str,
    dataset: Optional[str] = "alpaca_en",
    dataset_dir: Optional[str] = "data",
    template: Optional[str] = "default",
    interval: Optional[int] = 1000,
):
    # Build the arguments for a dummy SFT run; cutoff_len is set very high
    # so that no sample is truncated before its length is measured.
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            output_dir="dummy_dir",
            overwrite_cache=True,
        )
    )
    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
    total_num = len(trainset)
    # Histogram of tokenized lengths: each sample is counted in the bucket
    # [k * interval, (k + 1) * interval) keyed by its lower bound.
    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1
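
    # Sorting the buckets by length and accumulating their counts turns the
    # histogram into a cumulative distribution: after the bucket starting at
    # `length` is processed, count_accu is the number of samples shorter than
    # `length + interval`, and prob_accu is that count as a percentage.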
    length_tuples = list(length_dict.items())
    length_tuples.sort()
    count_accu, prob_accu = 0, 0
    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))


if __name__ == "__main__":
    fire.Fire(length_cdf)
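
# Example invocation (the model identifier below is illustrative; any local
# path or Hugging Face model that LLaMA-Factory can load should work, and
# every function parameter is exposed as a CLI flag via fire.Fire):
#   python length_cdf.py --model_name_or_path meta-llama/Llama-2-7b-hf --dataset alpaca_en --template default --interval 500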