loader.py

import os
from typing import TYPE_CHECKING, Any, Dict, List, Union

from datasets import concatenate_datasets, interleave_datasets, load_dataset

from llmtuner.data.utils import checksum, EXT2TYPE
from llmtuner.extras.logging import get_logger

if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from llmtuner.hparams import ModelArguments, DataArguments


logger = get_logger(__name__)


def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments"
) -> Union["Dataset", "IterableDataset"]:
    max_samples = data_args.max_samples
    all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets

    for dataset_attr in data_args.dataset_list:
        logger.info("Loading dataset {}...".format(dataset_attr))

        data_path, data_name, data_dir, data_files = None, None, None, None
        if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
            data_path = dataset_attr.dataset_name
            data_name = dataset_attr.subset
            data_dir = dataset_attr.folder
        elif dataset_attr.load_from == "script":
            data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            data_name = dataset_attr.subset
        elif dataset_attr.load_from == "file":
            data_files = []
            local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
            if os.path.isdir(local_path): # is directory
                for file_name in os.listdir(local_path):
                    data_files.append(os.path.join(local_path, file_name))
                    if data_path is None:
                        data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
                    else:
                        assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
            elif os.path.isfile(local_path): # is file
                data_files.append(local_path)
                data_path = EXT2TYPE.get(local_path.split(".")[-1], None)
            else:
                raise ValueError("File not found.")

            assert data_path, "File extension must be txt, csv, json or jsonl."
            checksum(data_files, dataset_attr.dataset_sha1)
        else:
            raise NotImplementedError

        if dataset_attr.load_from == "ms_hub":
57
            try:
58
                from modelscope import MsDataset # type: ignore
59
                from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
60

61
                cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
62
                dataset = MsDataset.load(
63
                    dataset_name=data_path,
64
                    subset_name=data_name,
65
                    data_dir=data_dir,
66
                    data_files=data_files,
67
                    split=data_args.split,
68
                    cache_dir=cache_dir,
69
                    token=model_args.ms_hub_token,
70
                    use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
71
                ).to_hf_dataset()
72
            except ImportError:
73
                raise ImportError("Please install modelscope via `pip install modelscope -U`")
74
        else:
75
            dataset = load_dataset(
76
                path=data_path,
77
                name=data_name,
78
                data_dir=data_dir,
79
                data_files=data_files,
80
                split=data_args.split,
81
                cache_dir=model_args.cache_dir,
82
                token=model_args.hf_hub_token,
83
                streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
84
            )
85

86
        if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
            dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter

        if max_samples is not None and not data_args.streaming: # truncate dataset (skipped when streaming: IterableDataset supports neither len() nor select())
            dataset = dataset.select(range(min(len(dataset), max_samples)))

        def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
            # convert dataset from sharegpt format to alpaca format
            outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
            for i, msg_list in enumerate(examples[dataset_attr.messages]):
                msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
                if len(msg_list) == 0:
                    continue

                msg_pairs = []
                user_role, assistant_role = None, None
                for idx in range(0, len(msg_list), 2):
                    if user_role is None and assistant_role is None:
                        user_role = msg_list[idx][dataset_attr.role]
                        assistant_role = msg_list[idx + 1][dataset_attr.role]
                    else:
                        if (
                            msg_list[idx][dataset_attr.role] != user_role
                            or msg_list[idx + 1][dataset_attr.role] != assistant_role
                        ):
                            raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
                    msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))

                if len(msg_pairs) != 0:
                    outputs["prompt"].append(msg_pairs[-1][0])
                    outputs["query"].append("")
                    outputs["response"].append(msg_pairs[-1][1])
                    outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
                    outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")

            return outputs

        if dataset_attr.formatting == "sharegpt": # convert format
124
            column_names = list(next(iter(dataset)).keys())
125
            kwargs = {}
126
            if not data_args.streaming:
127
                kwargs = dict(
128
                    num_proc=data_args.preprocessing_num_workers,
129
                    load_from_cache_file=(not data_args.overwrite_cache),
130
                    desc="Converting format of dataset"
131
                )
132

133
            dataset = dataset.map(
134
                convert_format,
135
                batched=True,
136
                remove_columns=column_names,
137
                **kwargs
138
            )
139
        else:
140
            for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
141
                if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
142
                    dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
143

144
        all_datasets.append(dataset)
145

146
    if len(data_args.dataset_list) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
        return concatenate_datasets(all_datasets)
    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=data_args.seed,
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
        )
    else:
        raise ValueError("Unknown mixing strategy.")
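

# Usage sketch. get_dataset() only relies on a small attribute contract:
# model_args must expose cache_dir / hf_hub_token / ms_hub_token, and
# data_args must expose dataset_list, dataset_dir, split, streaming,
# max_samples, mix_strategy, interleave_probs, seed,
# preprocessing_num_workers and overwrite_cache. In LLaMA-Factory these
# objects come from the hparams dataclasses filled in by the CLI parser;
# the SimpleNamespace stand-ins, the file name, and the column names below
# are illustrative assumptions only, not part of the original module.
if __name__ == "__main__":
    from types import SimpleNamespace

    # One local JSON file in alpaca format (hypothetical path under dataset_dir),
    # expected to contain "instruction" / "input" / "output" columns.
    alpaca_file = SimpleNamespace(
        load_from="file",
        dataset_name="alpaca_data_en_52k.json",
        dataset_sha1=None,          # checksum() only warns when no SHA-1 is given
        subset=None,
        folder=None,
        formatting="alpaca",        # takes the column-renaming branch, not convert_format
        prompt="instruction",
        query="input",
        response="output",
        history=None,
        system=None,
    )

    data_args = SimpleNamespace(
        dataset_list=[alpaca_file],
        dataset_dir="data",
        split="train",
        streaming=False,
        max_samples=100,            # keep only the first 100 examples
        mix_strategy="concat",      # irrelevant for a single dataset
        interleave_probs=None,
        seed=42,
        preprocessing_num_workers=None,
        overwrite_cache=False,
    )
    model_args = SimpleNamespace(cache_dir=None, hf_hub_token=None, ms_hub_token=None)

    dataset = get_dataset(model_args, data_args)
    print(dataset)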