import os
from typing import TYPE_CHECKING, Any, Dict, List, Union

from datasets import concatenate_datasets, interleave_datasets, load_dataset

from llmtuner.data.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from llmtuner.hparams import ModelArguments, DataArguments


logger = get_logger(__name__)


def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
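    # Load every dataset named in `data_args.dataset_list`, normalize each one to the
    # alpaca-style columns (prompt/query/response/history/system), then merge the
    # results according to `data_args.mix_strategy`.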
    max_samples = data_args.max_samples
    all_datasets: List[Union["Dataset", "IterableDataset"]] = []  # support multiple datasets

    for dataset_attr in data_args.dataset_list:
        logger.info("Loading dataset {}...".format(dataset_attr))
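
        # resolve where this dataset lives: a hub repository, a loading script, or local file(s)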
        data_path, data_name, data_dir, data_files = None, None, None, None
        if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
            data_path = dataset_attr.dataset_name
            data_name = dataset_attr.subset
            data_dir = dataset_attr.folder
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_name = dataset_attr.subset
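        # local files: collect the file list and infer the data type from the file extensions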
        elif dataset_attr.load_from == "file":
            data_files = []
            local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            if os.path.isdir(local_path):  # is directory
                for file_name in os.listdir(local_path):
                    data_files.append(os.path.join(local_path, file_name))
                    if data_path is None:
                        data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                    else:
                        assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
            elif os.path.isfile(local_path):  # is file
                data_files.append(local_path)
                data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."
            checksum(data_files, dataset_attr.dataset_sha1)
        else:
            raise NotImplementedError
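
        # load via ModelScope Hub when requested; the modelscope package is an optional dependency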
        if dataset_attr.load_from == "ms_hub":
            try:
                from modelscope import MsDataset  # type: ignore
                from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

                cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
                dataset = MsDataset.load(
                    dataset_name=data_path,
                    subset_name=data_name,
                    data_dir=data_dir,
                    data_files=data_files,
                    split=data_args.split,
                    cache_dir=cache_dir,
                    token=model_args.ms_hub_token,
                    use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
                ).to_hf_dataset()
            except ImportError:
                raise ImportError("Please install modelscope via `pip install modelscope -U`")
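        # otherwise load with the Hugging Face datasets library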
        else:
            dataset = load_dataset(
                path=data_path,
                name=data_name,
                data_dir=data_dir,
                data_files=data_files,
                split=data_args.split,
                cache_dir=model_args.cache_dir,
                token=model_args.hf_hub_token,
                streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
            )

        if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
            dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
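
        # select() only exists on map-style datasets, so this cap applies in non-streaming mode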
        if max_samples is not None:  # truncate dataset
            dataset = dataset.select(range(min(len(dataset), max_samples)))
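
        # sharegpt data stores a conversation as a list of {role, content} messages;
        # flatten it so the final exchange becomes prompt/response and earlier turns become history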
        def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
            # convert dataset from sharegpt format to alpaca format
            outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
            for i, msg_list in enumerate(examples[dataset_attr.messages]):
                msg_list = msg_list[:len(msg_list) // 2 * 2]  # truncate to an even number of messages
                if len(msg_list) == 0:
                    continue

                msg_pairs = []
                user_role, assistant_role = None, None
                for idx in range(0, len(msg_list), 2):
                    if user_role is None and assistant_role is None:
                        user_role = msg_list[idx][dataset_attr.role]
                        assistant_role = msg_list[idx + 1][dataset_attr.role]
                    elif (
                        msg_list[idx][dataset_attr.role] != user_role
                        or msg_list[idx + 1][dataset_attr.role] != assistant_role
                    ):
                        raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
                    msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))

                if len(msg_pairs) != 0:
                    outputs["prompt"].append(msg_pairs[-1][0])
                    outputs["query"].append("")
                    outputs["response"].append(msg_pairs[-1][1])
                    outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
                    outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")

            return outputs
        if dataset_attr.formatting == "sharegpt":  # convert format
            column_names = list(next(iter(dataset)).keys())
            kwargs = {}
            if not data_args.streaming:
                kwargs = dict(
                    num_proc=data_args.preprocessing_num_workers,
                    load_from_cache_file=(not data_args.overwrite_cache),
                    desc="Converting format of dataset"
                )

            dataset = dataset.map(
                convert_format,
                batched=True,
                remove_columns=column_names,
                **kwargs
            )
        else:
            for column_name in ["prompt", "query", "response", "history", "system"]:  # align dataset
                if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)

        all_datasets.append(dataset)
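
    # merge the per-dataset results according to the requested mixing strategy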
    if len(data_args.dataset_list) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
        return concatenate_datasets(all_datasets)
    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=data_args.seed,
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
        )
    else:
        raise ValueError("Unknown mixing strategy.")