mteb

Форк
0
/
test_all_abstasks.py 
78 строк · 2.5 Кб
1
from __future__ import annotations
2

3
import asyncio
4
import logging
5
from unittest.mock import Mock, patch
6

7
import aiohttp
8
import pytest
9

10
from mteb import MTEB
11
from mteb.abstasks import AbsTask
12
from mteb.abstasks.AbsTaskInstructionRetrieval import AbsTaskInstructionRetrieval
13
from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval
14
from mteb.abstasks.MultiSubsetLoader import MultiSubsetLoader
15

16
# Configure the root logger at INFO level so log output is visible while these tests run.
logging.basicConfig(level=logging.INFO)
17

18

19
@pytest.mark.parametrize("task", MTEB().tasks_cls)
@patch("datasets.load_dataset")
@patch("datasets.concatenate_datasets")
def test_load_data(
    mock_concatenate_datasets: Mock, mock_load_dataset: Mock, task: AbsTask
):
    """Check that ``task.load_data()`` calls the patched ``datasets.load_dataset``.

    ``datasets.load_dataset`` and ``datasets.concatenate_datasets`` are patched
    for the duration of the test. For tasks that are neither crosslingual nor
    multilingual, the task's ``dataset_transform`` hook must additionally be
    invoked exactly once.

    Args:
        mock_concatenate_datasets: Mock replacing ``datasets.concatenate_datasets``.
        mock_load_dataset: Mock replacing ``datasets.load_dataset``.
        task: The task under test, supplied by the parametrization.
    """
    # TODO: We skip these task families because their load_data is completely different.
    skipped_families = (
        AbsTaskRetrieval,
        AbsTaskInstructionRetrieval,
        MultiSubsetLoader,
    )
    if isinstance(task, skipped_families):
        pytest.skip()

    with patch.object(task, "dataset_transform") as transform_spy:
        task.load_data()
        mock_load_dataset.assert_called()

        # Crosslingual/multilingual tasks are not checked for the transform call;
        # they don't do it yet, but should they, so they can be expanded more easily?
        single_language = not (task.is_crosslingual or task.is_multilingual)
        if single_language:
            transform_spy.assert_called_once()
39

40

41
async def check_dataset_on_hf(
    session: aiohttp.ClientSession, dataset: str, revision: str
) -> bool:
    """Report whether the dataset's revision tree page on Hugging Face answers with HTTP 200.

    Args:
        session: Client session used to issue the HEAD request.
        dataset: Dataset path interpolated into the Hugging Face URL.
        revision: Revision interpolated into the ``/tree/`` part of the URL.

    Returns:
        ``True`` if the HEAD response status is exactly 200, ``False`` otherwise.
    """
    tree_url = f"https://huggingface.co/datasets/{dataset}/tree/{revision}"
    async with session.head(tree_url) as head_response:
        is_available = head_response.status == 200
    return is_available
47

48

49
async def check_datasets_are_available_on_hf(tasks):
    """Check that every task's dataset path and revision can be found on Hugging Face.

    One ``check_dataset_on_hf`` call is issued per task, all sharing a single
    ``aiohttp.ClientSession`` and awaited together with ``asyncio.gather``.

    Args:
        tasks: Iterable of tasks exposing ``metadata.dataset["path"]`` and
            ``metadata.dataset["revision"]``.

    Raises:
        AssertionError: If at least one dataset/revision pair was reported as
            unavailable; the message lists every missing pair.
    """
    # Materialise once: the tasks are iterated twice (to launch the checks and
    # to pair them with their results), which a one-shot iterable would break.
    tasks = list(tasks)

    async with aiohttp.ClientSession() as session:
        datasets_exists = await asyncio.gather(
            *(
                check_dataset_on_hf(
                    session,
                    task.metadata.dataset["path"],
                    task.metadata.dataset["revision"],
                )
                for task in tasks
            )
        )

    does_not_exist = [
        (task.metadata.dataset["path"], task.metadata.dataset["revision"])
        for task, ds_exists in zip(tasks, datasets_exists)
        if not ds_exists
    ]

    if does_not_exist:
        pretty_print = "\n".join(
            f"{path} - revision {revision}" for path, revision in does_not_exist
        )
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would silently hide the failure.
        raise AssertionError(
            f"Datasets not available on Hugging Face:\n{pretty_print}"
        )
73

74

75
def test_dataset_availability():
    """Verify that the dataset name and revision of every MTEB task resolve on Hugging Face."""
    availability_check = check_datasets_are_available_on_hf(MTEB().tasks_cls)
    asyncio.run(availability_check)
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.