1
from __future__ import annotations
6
from mteb import get_tasks
7
from mteb.abstasks.TaskMetadata import TASK_DOMAIN, TASK_TYPE
8
from mteb.overview import MTEBTasks
11
def test_get_tasks_size_differences():
    """Check that adding filters to ``get_tasks`` never enlarges the result.

    The unfiltered task list must be non-empty and at least as large as any
    list filtered by language, script, or domain. Filtering on two languages
    must also return at least as many tasks as filtering on just one of them,
    because a task is selected when any of the requested languages matches.
    """
    all_tasks = get_tasks()
    assert len(all_tasks) > 0
    assert len(all_tasks) >= len(get_tasks(languages=["eng"]))
    assert len(all_tasks) >= len(get_tasks(script=["Latn"]))
    assert len(all_tasks) >= len(get_tasks(domains=["Legal"]))
    assert len(all_tasks) >= len(get_tasks(languages=["eng", "deu"]))
    # The previous version compared the eng+deu selection with itself, which
    # can never fail. Compare it against the single-language subset instead.
    assert len(get_tasks(languages=["eng", "deu"])) >= len(
        get_tasks(languages=["eng"])
    )
22
@pytest.mark.parametrize("languages", [["eng", "deu"], ["eng"], None])
23
@pytest.mark.parametrize("script", [["Latn"], ["Cyrl"], None])
24
@pytest.mark.parametrize("domains", [["Legal"], ["Medical", "Non-fiction"], None])
25
@pytest.mark.parametrize("task_types", [["Classification"], ["Clustering"], None])
26
@pytest.mark.parametrize("exclude_superseeded_datasets", [True, False])
30
domains: list[TASK_DOMAIN],
31
task_types: list[TASK_TYPE] | None,
32
exclude_superseeded_datasets: bool,
34
tasks = mteb.get_tasks(
38
task_types=task_types,
39
exclude_superseeded=exclude_superseeded_datasets,
44
assert set(languages).intersection(task.metadata.languages)
46
assert set(script).intersection(task.metadata.scripts)
49
set(task.metadata.domains) if task.metadata.domains else set()
51
assert set(domains).intersection(set(task_domains))
53
assert task.metadata.type in task_types
54
if exclude_superseeded_datasets:
55
assert task.superseeded_by is None
58
def test_get_tasks_filtering():
59
"""Tests that get_tasks filters tasks for languages within the task, i.e. that a multilingual task returns only relevant subtasks for the
62
tasks = get_tasks(languages=["eng"])
65
if task.is_multilingual:
66
assert isinstance(task.metadata.eval_langs, dict)
68
for hf_split in task.langs:
69
assert "eng-Latn" in task.metadata.eval_langs[hf_split]
72
@pytest.mark.parametrize("script", [["Latn"], ["Cyrl"], None])
73
@pytest.mark.parametrize("task_types", [["Classification"], ["Clustering"], None])
76
task_types: list[TASK_TYPE] | None,
78
tasks = mteb.get_tasks(script=script, task_types=task_types)
79
assert isinstance(tasks, MTEBTasks)
80
langs = tasks.languages
82
assert len(langs.intersection(t.languages)) > 0
86
assert len(tasks.to_markdown().split("\n")) - 3 == n_langs