mteb

Форк
0
/
test_overview.py 
86 строк · 2.9 Кб
1
from __future__ import annotations
2

3
import pytest
4

5
import mteb
6
from mteb import get_tasks
7
from mteb.abstasks.TaskMetadata import TASK_DOMAIN, TASK_TYPE
8
from mteb.overview import MTEBTasks
9

10

11
def test_get_tasks_size_differences():
    """Check that filtered task lists are never larger than broader ones.

    The unfiltered list must be non-empty and at least as large as every
    filtered list, and a two-language filter must return at least as many
    tasks as a filter on just one of those languages.
    """
    n_all = len(get_tasks())
    assert n_all > 0

    n_eng = len(get_tasks(languages=["eng"]))
    n_eng_deu = len(get_tasks(languages=["eng", "deu"]))

    assert n_all >= n_eng
    assert n_all >= len(get_tasks(script=["Latn"]))
    assert n_all >= len(get_tasks(domains=["Legal"]))
    assert n_all >= n_eng_deu
    # A wider language filter must match at least as many tasks as a
    # narrower one.
    assert n_eng_deu >= n_eng
20

21

22
@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"], None])
@pytest.mark.parametrize("script", [["Latn"], ["Cyrl"], None])
@pytest.mark.parametrize("domains", [["Legal"], ["Medical", "Non-fiction"], None])
@pytest.mark.parametrize("task_types", [["Classification"], ["Clustering"], None])
@pytest.mark.parametrize("exclude_superseeded_datasets", [True, False])
def test_get_task(
    languages: list[str] | None,
    script: list[str] | None,
    domains: list[TASK_DOMAIN] | None,
    task_types: list[TASK_TYPE] | None,
    exclude_superseeded_datasets: bool,
):
    """Check that every task returned by ``mteb.get_tasks`` satisfies the filters.

    Each list-valued argument is either a list of accepted values or
    ``None``; when it is ``None`` (or empty) the corresponding check is
    skipped.

    Args:
        languages: Language codes passed as the ``languages`` filter.
        script: Script codes passed as the ``script`` filter.
        domains: Domains passed as the ``domains`` filter.
        task_types: Task types passed as the ``task_types`` filter.
        exclude_superseeded_datasets: Forwarded as ``exclude_superseeded``;
            when true, no returned task may have ``superseeded_by`` set.
    """
    tasks = mteb.get_tasks(
        languages=languages,
        script=script,
        domains=domains,
        task_types=task_types,
        exclude_superseeded=exclude_superseeded_datasets,
    )

    for task in tasks:
        if languages:
            assert set(languages).intersection(task.metadata.languages)
        if script:
            assert set(script).intersection(task.metadata.scripts)
        if domains:
            # ``metadata.domains`` may be falsy; treat that as "no domains".
            task_domains = set(task.metadata.domains or ())
            assert task_domains.intersection(domains)
        if task_types:
            assert task.metadata.type in task_types
        if exclude_superseeded_datasets:
            assert task.superseeded_by is None
56

57

58
def test_get_tasks_filtering():
    """Verify that language filtering also applies inside multilingual tasks.

    After requesting English tasks, every entry listed in a multilingual
    task's ``langs`` must map to evaluation languages containing
    ``eng-Latn``.
    """
    english_tasks = get_tasks(languages=["eng"])

    for task in english_tasks:
        # Only multilingual tasks are checked here.
        if not task.is_multilingual:
            continue

        eval_langs = task.metadata.eval_langs
        assert isinstance(eval_langs, dict)

        for lang_key in task.langs:
            assert "eng-Latn" in eval_langs[lang_key]
70

71

72
@pytest.mark.parametrize("script", [["Latn"], ["Cyrl"], None])
@pytest.mark.parametrize("task_types", [["Classification"], ["Clustering"], None])
def test_MTEBTasks(
    script: list[str],
    task_types: list[TASK_TYPE] | None,
):
    """Exercise the ``MTEBTasks`` container returned by ``mteb.get_tasks``.

    Checks the return type, that each task shares at least one language
    with the container-level ``languages``, and that the markdown rendering
    has exactly three more lines than there are tasks.
    """
    selected = mteb.get_tasks(script=script, task_types=task_types)
    assert isinstance(selected, MTEBTasks)

    all_languages = selected.languages
    for task in selected:
        shared = all_languages.intersection(task.languages)
        assert len(shared) > 0

    # The rendered table carries three lines beyond one per task
    # (presumably header, separator and a trailing line; confirm against
    # ``MTEBTasks.to_markdown``).
    n_tasks = len(selected)
    markdown_lines = selected.to_markdown().split("\n")
    assert len(markdown_lines) - 3 == n_tasks
87

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.