llama-index
262 строки · 9.1 Кб
1"""Download."""
2
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
import tqdm

from llama_index.legacy.download.module import LLAMA_HUB_URL
from llama_index.legacy.download.utils import (
    get_file_content,
    get_file_content_bytes,
    initialize_directory,
)
17
# Base URL (llama-datasets LFS media host) from which dataset files and
# their source files are downloaded.
LLAMA_DATASETS_LFS_URL = (
    "https://media.githubusercontent.com/media/run-llama/llama-datasets/main"
)

# Base GitHub tree URL used to list the source files of a dataset.
LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL = (
    "https://github.com/run-llama/llama-datasets/tree/main"
)
# Default name of the directory holding a dataset's source files.
LLAMA_SOURCE_FILES_PATH = "source_files"

# Maps a dataset class name (both "Labelled" and "Labeled" spellings are
# accepted) to the JSON filename the dataset is stored under.
DATASET_CLASS_FILENAME_REGISTRY = {
    "LabelledRagDataset": "rag_dataset.json",
    "LabeledRagDataset": "rag_dataset.json",
    "LabelledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
    "LabeledPairwiseEvaluatorDataset": "pairwise_evaluator_dataset.json",
    "LabelledEvaluatorDataset": "evaluator_dataset.json",
    "LabeledEvaluatorDataset": "evaluator_dataset.json",
}


# Filesystem locations may be passed either as strings or as Path objects.
PATH_TYPE = Union[str, Path]
39
def _resolve_dataset_file_name(class_name: str) -> str:
    """Return the JSON filename that datasets of ``class_name`` are stored under.

    Args:
        class_name: Dataset class name, looked up in
            ``DATASET_CLASS_FILENAME_REGISTRY``.

    Raises:
        ValueError: If ``class_name`` has no registry entry (chained from the
            underlying ``KeyError``).
    """
    registry = DATASET_CLASS_FILENAME_REGISTRY
    try:
        file_name = registry[class_name]
    except KeyError as missing_key:
        raise ValueError("Invalid dataset filename.") from missing_key
    return file_name
47
def _get_source_files_list(
    source_tree_url: str, path: str, timeout: Optional[float] = None
) -> List[str]:
    """Get the list of source-file names to download for a dataset.

    Requests ``source_tree_url + path + "?recursive=1"`` and collects the
    ``name`` of every entry under ``payload.tree.items`` in the JSON reply.

    Args:
        source_tree_url: Base URL of the source tree listing.
        path: Path appended to ``source_tree_url``, e.g.
            ``"/<dataset_id>/source_files"``.
        timeout: Optional request timeout in seconds, forwarded to
            ``requests.get``. The default ``None`` waits indefinitely, matching
            the previous behaviour.

    Returns:
        The names of the items in the listing.

    Raises:
        requests.HTTPError: If the listing request returns a 4xx/5xx status.
    """
    resp = requests.get(source_tree_url + path + "?recursive=1", timeout=timeout)
    # An error page is not the JSON payload we expect; fail with a clear HTTP
    # error instead of an opaque JSON-decoding or KeyError further down.
    resp.raise_for_status()
    payload = resp.json()["payload"]
    return [item["name"] for item in payload["tree"]["items"]]
54
def get_dataset_info(
    local_dir_path: PATH_TYPE,
    remote_dir_path: PATH_TYPE,
    remote_source_dir_path: PATH_TYPE,
    dataset_class: str,
    refresh_cache: bool = False,
    library_path: str = "library.json",
    source_files_path: str = "source_files",
    disable_library_cache: bool = False,
) -> Dict:
    """Resolve a dataset's id, class name and source-file list.

    Unless ``refresh_cache`` is set, the local library file
    ``<local_dir_path>/<library_path>`` is consulted first. A cached entry is
    only used when it records the dataset's ``className``. Otherwise, the
    remote library is fetched, together with the dataset's ``card.json`` and
    (for ``LabelledRagDataset``) its source-file listing. Then the local
    library cache is rewritten with that information unless
    ``disable_library_cache`` is set.

    Args:
        local_dir_path: Local directory holding the cached library file.
        remote_dir_path: Base URL of the remote library file and data cards.
        remote_source_dir_path: Base URL used to list a dataset's source files.
        dataset_class: Key of the dataset in the library file.
        refresh_cache: If true, ignore the local library cache.
        library_path: Path of the library file, relative to both
            ``local_dir_path`` and ``remote_dir_path``.
        source_files_path: Name of the dataset's source-file directory in the
            remote tree.
        disable_library_cache: If true, do not write the local library cache.

    Returns:
        A dict with keys ``"dataset_id"``, ``"dataset_class_name"`` and
        ``"source_files"``.

    Raises:
        ValueError: If ``dataset_class`` is not listed in the remote library.
    """
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    local_library_path = f"{local_dir_path}/{library_path}"

    # Check cache first
    if not refresh_cache:
        cached_info = _read_cached_dataset_info(local_library_path, dataset_class)
        if cached_info is not None:
            return cached_info

    # Cache miss (or refresh requested): fetch the up-to-date library.
    library_raw_content, _ = get_file_content(
        str(remote_dir_path), f"/{library_path}"
    )
    library = json.loads(library_raw_content)
    if dataset_class not in library:
        raise ValueError("Dataset class name not found in library")

    dataset_id = library[dataset_class]["id"]

    # The data card carries the concrete dataset class name.
    raw_card_content, _ = get_file_content(
        str(remote_dir_path), f"/{dataset_id}/card.json"
    )
    dataset_class_name = json.loads(raw_card_content)["className"]

    source_files: List[str] = []
    if dataset_class_name == "LabelledRagDataset":
        source_files = _get_source_files_list(
            str(remote_source_dir_path), f"/{dataset_id}/{source_files_path}"
        )

    if not disable_library_cache:
        # Record everything a later cache hit needs, so it does not have to
        # go back to the network for the class name or the source files.
        library[dataset_class]["className"] = dataset_class_name
        library[dataset_class]["source_files"] = source_files
        _write_library_cache(local_library_path, library)

    return {
        "dataset_id": dataset_id,
        "dataset_class_name": dataset_class_name,
        "source_files": source_files,
    }


def _read_cached_dataset_info(
    local_library_path: str, dataset_class: str
) -> Optional[Dict]:
    """Return dataset info from the local library cache, or None on a miss."""
    if not os.path.exists(local_library_path):
        return None
    with open(local_library_path) as f:
        library = json.load(f)
    entry = library.get(dataset_class)
    # Entries copied verbatim from the remote library are not guaranteed to
    # carry the class name (it comes from card.json), so an entry without it
    # is incomplete and is treated as a miss.
    if entry is None or "className" not in entry:
        return None
    return {
        "dataset_id": entry["id"],
        "dataset_class_name": entry["className"],
        "source_files": entry.get("source_files", []),
    }


def _write_library_cache(local_library_path: str, library: Dict) -> None:
    """Write ``library`` as JSON to ``local_library_path``, creating parent dirs."""
    local_library_dir = os.path.dirname(local_library_path)
    if local_library_dir:
        os.makedirs(local_library_dir, exist_ok=True)
    with open(local_library_path, "w") as f:
        json.dump(library, f, indent=4)
124
def download_dataset_and_source_files(
    local_dir_path: PATH_TYPE,
    remote_lfs_dir_path: PATH_TYPE,
    source_files_dir_path: PATH_TYPE,
    dataset_id: str,
    dataset_class_name: str,
    source_files: List[str],
    refresh_cache: bool = False,
    base_file_name: str = "rag_dataset.json",
    override_path: bool = False,
    show_progress: bool = False,
) -> None:
    """Download a dataset file and, for RAG datasets, its source files.

    Files are written under ``<local_dir_path>/<dataset_id>``, or directly
    under ``local_dir_path`` when ``override_path`` is set. If the dataset
    file is already present there, nothing is downloaded unless
    ``refresh_cache`` is set.

    Args:
        local_dir_path: Local parent directory for the download.
        remote_lfs_dir_path: Base URL the files are downloaded from.
        source_files_dir_path: Name of the source-file directory, used both
            remotely and locally.
        dataset_id: Dataset id; also the remote (and default local) sub-path.
        dataset_class_name: Dataset class name; selects the dataset filename
            and whether source files are downloaded.
        source_files: Names of the source files to download.
        refresh_cache: If true, download even if the dataset file exists.
        base_file_name: Ignored; the filename is always resolved from
            ``dataset_class_name``. Kept for backward compatibility.
        override_path: If true, write directly into ``local_dir_path``.
        show_progress: If true, show a tqdm progress bar over source files.
    """
    if isinstance(local_dir_path, str):
        local_dir_path = Path(local_dir_path)

    if override_path:
        module_path = str(local_dir_path)
    else:
        module_path = f"{local_dir_path}/{dataset_id}"

    base_file_name = _resolve_dataset_file_name(dataset_class_name)
    dataset_file_path = f"{module_path}/{base_file_name}"

    # Check the dataset file itself rather than the directory. An existing
    # but incomplete directory (e.g. when override_path points at an existing
    # folder) must not prevent the download.
    if not refresh_cache and os.path.exists(dataset_file_path):
        return
    os.makedirs(module_path, exist_ok=True)

    dataset_raw_content, _ = get_file_content(
        str(remote_lfs_dir_path), f"/{dataset_id}/{base_file_name}"
    )
    with open(dataset_file_path, "w") as f:
        f.write(dataset_raw_content)

    # Only RAG datasets ship with source files.
    if dataset_class_name != "LabelledRagDataset":
        return

    os.makedirs(f"{module_path}/{source_files_dir_path}", exist_ok=True)
    source_files_iterator = tqdm.tqdm(source_files) if show_progress else source_files
    for source_file in source_files_iterator:
        _download_source_file(
            remote_lfs_dir_path=remote_lfs_dir_path,
            module_path=module_path,
            source_files_dir_path=source_files_dir_path,
            dataset_id=dataset_id,
            source_file=source_file,
        )


def _download_source_file(
    remote_lfs_dir_path: PATH_TYPE,
    module_path: str,
    source_files_dir_path: PATH_TYPE,
    dataset_id: str,
    source_file: str,
) -> None:
    """Download one source file: PDFs as raw bytes, everything else as text."""
    remote_path = f"/{dataset_id}/{source_files_dir_path}/{source_file}"
    local_path = f"{module_path}/{source_files_dir_path}/{source_file}"
    # Match on the extension (case-insensitively), not a substring, so that
    # "report.PDF" is fetched as binary and "notes.pdf.txt" as text.
    if source_file.lower().endswith(".pdf"):
        content_bytes, _ = get_file_content_bytes(str(remote_lfs_dir_path), remote_path)
        with open(local_path, "wb") as f:
            f.write(content_bytes)
    else:
        content, _ = get_file_content(str(remote_lfs_dir_path), remote_path)
        with open(local_path, "w") as f:
            f.write(content)
185
def download_llama_dataset(
    dataset_class: str,
    llama_hub_url: str = LLAMA_HUB_URL,
    llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL,
    llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL,
    refresh_cache: bool = False,
    custom_dir: Optional[str] = None,
    custom_path: Optional[str] = None,
    source_files_dirpath: str = LLAMA_SOURCE_FILES_PATH,
    library_path: str = "llama_datasets/library.json",
    disable_library_cache: bool = False,
    override_path: bool = False,
    show_progress: bool = False,
) -> Tuple[str, str]:
    """Download a llama dataset (and its source files) from LlamaHub.

    Args:
        dataset_class: Key of the dataset in the llama-datasets library file.
        llama_hub_url: Base URL of the remote library file and data cards.
        llama_datasets_lfs_url: Base URL the dataset and source files are
            downloaded from.
        llama_datasets_source_files_tree_url: Base URL used to list a
            dataset's source files.
        refresh_cache: If true, the local cache is skipped and the dataset is
            fetched directly from the remote repo.
        custom_dir: Custom dir name to download the dataset into (under
            parent folder).
        custom_path: Custom dirpath to download the dataset into.
        source_files_dirpath: Name of the source-file directory, used both
            for the remote listing/download and for the local copy.
        library_path: Path of the library file, relative to both the remote
            URL and the local directory.
        disable_library_cache: If true, the local library cache is not written.
        override_path: If true, download directly into the target directory
            instead of a ``<dataset_id>`` sub-directory.
        show_progress: If true, show a progress bar while downloading source
            files.

    Returns:
        A ``(dataset_file_path, source_files_dir_path)`` tuple of local paths.
    """
    # create directory / get path
    dirpath = initialize_directory(custom_path=custom_path, custom_dir=custom_dir)

    # fetch info from library.json file; pass source_files_dirpath so the
    # listing uses the same remote directory the download reads from.
    dataset_info = get_dataset_info(
        local_dir_path=dirpath,
        remote_dir_path=llama_hub_url,
        remote_source_dir_path=llama_datasets_source_files_tree_url,
        dataset_class=dataset_class,
        refresh_cache=refresh_cache,
        library_path=library_path,
        source_files_path=source_files_dirpath,
        disable_library_cache=disable_library_cache,
    )
    dataset_id = dataset_info["dataset_id"]
    source_files = dataset_info["source_files"]
    dataset_class_name = dataset_info["dataset_class_name"]

    dataset_filename = _resolve_dataset_file_name(dataset_class_name)

    download_dataset_and_source_files(
        local_dir_path=dirpath,
        remote_lfs_dir_path=llama_datasets_lfs_url,
        source_files_dir_path=source_files_dirpath,
        dataset_id=dataset_id,
        dataset_class_name=dataset_class_name,
        source_files=source_files,
        refresh_cache=refresh_cache,
        override_path=override_path,
        show_progress=show_progress,
    )

    if override_path:
        module_path = str(dirpath)
    else:
        module_path = f"{dirpath}/{dataset_id}"

    # Report the directory the source files were actually written to, which
    # is source_files_dirpath (not necessarily the module default).
    return (
        f"{module_path}/{dataset_filename}",
        f"{module_path}/{source_files_dirpath}",
    )