llama-index

Форк
0
262 строки · 9.1 Кб
1
"""Download."""
2

3
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
import tqdm

from llama_index.legacy.download.module import LLAMA_HUB_URL
from llama_index.legacy.download.utils import (
    get_file_content,
    get_file_content_bytes,
    initialize_directory,
)
# Default base URL that dataset files and their source files are downloaded from.
LLAMA_DATASETS_LFS_URL = (
    "https://media.githubusercontent.com/media/run-llama/llama-datasets/main"
)

# Default GitHub "tree" URL used to list the source files of a dataset.
LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL = (
    "https://github.com/run-llama/llama-datasets/tree/main"
)
# Default name of the folder that holds a dataset's source files.
LLAMA_SOURCE_FILES_PATH = "source_files"

# Maps each dataset class name to the file name its data is stored under.
# Both the "Labelled" and "Labeled" spellings map to the same file.
DATASET_CLASS_FILENAME_REGISTRY = {
    "LabelledRagDataset": "rag_dataset.json",
    "LabeledRagDataset": "rag_dataset.json",
    "LabelledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
    "LabeledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
    "LabelledEvaluatorDataset": "evaluator_dataset.json",
    "LabeledEvaluatorDataset": "evaluator_dataset.json",
}


# Path-like argument type: either a plain string or a ``pathlib.Path``.
PATH_TYPE = Union[str, Path]


def _resolve_dataset_file_name(class_name: str) -> str:
    """Return the file name that a dataset of class ``class_name`` is stored under.

    Args:
        class_name: Dataset class name; must be a key of
            ``DATASET_CLASS_FILENAME_REGISTRY``.

    Returns:
        The JSON file name registered for that class.

    Raises:
        ValueError: If ``class_name`` is not a registered dataset class.
    """
    try:
        return DATASET_CLASS_FILENAME_REGISTRY[class_name]
    except KeyError as err:
        # The lookup key is a class name, so report that (and the accepted
        # names) rather than a misleading "filename" error.
        raise ValueError(
            f"Invalid dataset class name: {class_name!r}. Expected one of "
            f"{sorted(DATASET_CLASS_FILENAME_REGISTRY)}."
        ) from err


def _get_source_files_list(source_tree_url: str, path: str) -> List[str]:
    """List the source file names found under ``path`` in a GitHub tree page.

    Args:
        source_tree_url: Base GitHub "tree" URL of the repository.
        path: Folder path to list, appended verbatim to ``source_tree_url``
            (callers pass it with a leading ``/``).

    Returns:
        The ``name`` of every entry in the response's ``payload.tree.items``.

    Raises:
        requests.HTTPError: If the listing request returns an error status.
    """
    resp = requests.get(source_tree_url + path, params={"recursive": "1"})
    # Fail with a clear HTTP error (e.g. 404 or rate limiting) instead of an
    # obscure JSON decode / KeyError on an error page.
    resp.raise_for_status()
    payload = resp.json()["payload"]
    return [item["name"] for item in payload["tree"]["items"]]


def get_dataset_info(
    local_dir_path: PATH_TYPE,
    remote_dir_path: PATH_TYPE,
    remote_source_dir_path: PATH_TYPE,
    dataset_class: str,
    refresh_cache: bool = False,
    library_path: str = "library.json",
    source_files_path: str = "source_files",
    disable_library_cache: bool = False,
) -> Dict:
    """Look up a dataset in the library index and describe it.

    The dataset id is read from the locally cached library index when
    possible. Otherwise the index is fetched from ``remote_dir_path``, and it
    is written to the local cache unless ``disable_library_cache`` is set.
    The dataset's ``card.json`` is always fetched, because the class name is
    not stored in the library index. For RAG datasets, the source files are
    then listed.

    Args:
        local_dir_path: Local directory under which the library index is cached.
        remote_dir_path: Base URL holding the library index and dataset cards.
        remote_source_dir_path: Base GitHub tree URL used to list source files.
        dataset_class: Library key of the dataset to look up.
        refresh_cache: If true, ignore the locally cached library index.
        library_path: Path of the library index, relative to both
            ``remote_dir_path`` and ``local_dir_path``.
        source_files_path: Name of the folder (under the dataset id) that
            holds the dataset's source files.
        disable_library_cache: If true, do not write the fetched library
            index to the local cache.

    Returns:
        A dict with keys ``"dataset_id"``, ``"dataset_class_name"`` and
        ``"source_files"``.

    Raises:
        ValueError: If ``dataset_class`` is not present in the library.
    """
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    local_library_path = f"{local_dir_path}/{library_path}"
    dataset_id = None
    source_files: List[str] = []

    # Check the local library cache first.
    if not refresh_cache and os.path.exists(local_library_path):
        with open(local_library_path, encoding="utf-8") as f:
            library = json.load(f)
        if dataset_class in library:
            dataset_id = library[dataset_class]["id"]
            source_files = library[dataset_class].get("source_files", [])

    # Cache miss: fetch the up-to-date library from the remote repo.
    if dataset_id is None:
        library_raw_content, _ = get_file_content(
            str(remote_dir_path), f"/{library_path}"
        )
        library = json.loads(library_raw_content)
        if dataset_class not in library:
            raise ValueError("Dataset class name not found in library")
        dataset_id = library[dataset_class]["id"]

        if not disable_library_cache:
            os.makedirs(os.path.dirname(local_library_path), exist_ok=True)
            with open(local_library_path, "w", encoding="utf-8") as f:
                f.write(library_raw_content)

    # The class name only lives in the dataset card, so read it on cache hits
    # too; otherwise ``dataset_class_name`` would be unbound below.
    raw_card_content, _ = get_file_content(
        str(remote_dir_path), f"/{dataset_id}/card.json"
    )
    dataset_class_name = json.loads(raw_card_content)["className"]

    # Only RAG datasets ship source files; list them unless the cached
    # library entry already provided them.
    if not source_files and dataset_class_name in (
        "LabelledRagDataset",
        "LabeledRagDataset",
    ):
        source_files = _get_source_files_list(
            str(remote_source_dir_path), f"/{dataset_id}/{source_files_path}"
        )

    return {
        "dataset_id": dataset_id,
        "dataset_class_name": dataset_class_name,
        "source_files": source_files,
    }


def download_dataset_and_source_files(
    local_dir_path: PATH_TYPE,
    remote_lfs_dir_path: PATH_TYPE,
    source_files_dir_path: PATH_TYPE,
    dataset_id: str,
    dataset_class_name: str,
    source_files: List[str],
    refresh_cache: bool = False,
    base_file_name: str = "rag_dataset.json",
    override_path: bool = False,
    show_progress: bool = False,
) -> None:
    """Download a dataset file and, for RAG datasets, its source files.

    Files are written under ``local_dir_path/dataset_id``, or directly under
    ``local_dir_path`` when ``override_path`` is true. Nothing is downloaded
    if the dataset file already exists there, unless ``refresh_cache`` is set.

    Args:
        local_dir_path: Local parent directory to download into.
        remote_lfs_dir_path: Base URL the dataset and source files are
            fetched from.
        source_files_dir_path: Name of the source-files folder, used both
            remotely (under the dataset id) and locally (under the module path).
        dataset_id: Dataset id (folder name) in the remote repository.
        dataset_class_name: Dataset class name. It selects the dataset file
            name and whether source files are downloaded.
        source_files: Names of the source files to download.
        refresh_cache: If true, re-download even if the dataset file exists.
        base_file_name: Ignored; kept for backward compatibility. The file
            name is always derived from ``dataset_class_name``.
        override_path: If true, download directly into ``local_dir_path``.
        show_progress: If true, show a tqdm progress bar over source files.

    Raises:
        ValueError: If ``dataset_class_name`` is not a registered dataset class.
    """
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    if override_path:
        module_path = str(local_dir_path)
    else:
        module_path = f"{local_dir_path}/{dataset_id}"

    # The file name is derived from the class; ``base_file_name`` is overridden.
    base_file_name = _resolve_dataset_file_name(dataset_class_name)
    dataset_file_path = f"{module_path}/{base_file_name}"

    # Use the dataset file (written last, below) as the cache marker. A bare
    # directory-existence check skips the download whenever the directory is
    # already there, e.g. with ``override_path`` or after an interrupted run.
    if not refresh_cache and os.path.exists(dataset_file_path):
        return

    os.makedirs(module_path, exist_ok=True)

    # Get content of source files (RAG datasets only).
    if dataset_class_name in ("LabelledRagDataset", "LabeledRagDataset"):
        local_source_dir = f"{module_path}/{source_files_dir_path}"
        os.makedirs(local_source_dir, exist_ok=True)
        source_files_iterator = (
            tqdm.tqdm(source_files) if show_progress else source_files
        )
        for source_file in source_files_iterator:
            remote_path = f"/{dataset_id}/{source_files_dir_path}/{source_file}"
            local_path = f"{local_source_dir}/{source_file}"
            # PDFs are binary and must be written byte-for-byte.
            if ".pdf" in source_file.lower():
                content_bytes, _ = get_file_content_bytes(
                    str(remote_lfs_dir_path), remote_path
                )
                with open(local_path, "wb") as f:
                    f.write(content_bytes)
            else:
                content, _ = get_file_content(str(remote_lfs_dir_path), remote_path)
                with open(local_path, "w", encoding="utf-8") as f:
                    f.write(content)

    # Written last so that its presence implies a complete download.
    dataset_raw_content, _ = get_file_content(
        str(remote_lfs_dir_path), f"/{dataset_id}/{base_file_name}"
    )
    with open(dataset_file_path, "w", encoding="utf-8") as f:
        f.write(dataset_raw_content)


def download_llama_dataset(
    dataset_class: str,
    llama_hub_url: str = LLAMA_HUB_URL,
    llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
    llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    refresh_cache: bool = False,
    custom_dir: Optional[str] = None,
    custom_path: Optional[str] = None,
    source_files_dirpath: str = LLAMA_SOURCE_FILES_PATH,
    library_path: str = "llama_datasets/library.json",
    disable_library_cache: bool = False,
    override_path: bool = False,
    show_progress: bool = False,
) -> Tuple[str, str]:
    """Download a llama dataset and its source files from LlamaHub.

    Args:
        dataset_class: Name of the dataset to download, as keyed in the
            LlamaHub datasets library index.
        llama_hub_url: Base URL holding the library index and dataset cards.
        llama_datasets_lfs_url: Base URL the dataset and source files are
            downloaded from.
        llama_datasets_source_files_tree_url: GitHub tree URL used to list a
            dataset's source files.
        refresh_cache: If true, ignore local caches and re-fetch everything.
        custom_dir: Custom directory name to download into (under the parent
            folder).
        custom_path: Custom directory path to download into.
        source_files_dirpath: Name of the source-files folder, both in the
            remote repository and locally.
        library_path: Path of the library index file, relative to
            ``llama_hub_url`` and to the local download directory.
        disable_library_cache: If true, do not cache the fetched library index.
        override_path: If true, download directly into the target directory
            instead of a ``<dataset_id>`` subfolder.
        show_progress: If true, show a progress bar over the source files.

    Returns:
        A ``(dataset_file_path, source_files_dir_path)`` tuple of local paths.
    """
    # create directory / get path
    dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)

    # fetch info from library.json file
    dataset_info = get_dataset_info(
        local_dir_path=dirpath,
        remote_dir_path=llama_hub_url,
        remote_source_dir_path=llama_datasets_source_files_tree_url,
        dataset_class=dataset_class,
        refresh_cache=refresh_cache,
        library_path=library_path,
        # List source files from the same folder they are downloaded from.
        source_files_path=source_files_dirpath,
        disable_library_cache=disable_library_cache,
    )
    dataset_id = dataset_info["dataset_id"]
    source_files = dataset_info["source_files"]
    dataset_class_name = dataset_info["dataset_class_name"]

    # Resolve before downloading so an unknown class fails fast.
    dataset_filename = _resolve_dataset_file_name(dataset_class_name)

    download_dataset_and_source_files(
        local_dir_path=dirpath,
        remote_lfs_dir_path=llama_datasets_lfs_url,
        source_files_dir_path=source_files_dirpath,
        dataset_id=dataset_id,
        dataset_class_name=dataset_class_name,
        source_files=source_files,
        refresh_cache=refresh_cache,
        override_path=override_path,
        show_progress=show_progress,
    )

    module_path = str(dirpath) if override_path else f"{dirpath}/{dataset_id}"

    return (
        f"{module_path}/{dataset_filename}",
        # Report the folder actually used, not the module default name.
        f"{module_path}/{source_files_dirpath}",
    )

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.