OpenDelta

Форк
0
102 строки · 4.1 Кб
1
import abc
2
from typing import Callable, List, Mapping, Dict
3
import datasets
4
import logging
5
import numpy as np
6
import torch
7
logger = logging.getLogger(__name__)
8

9

10
class AbstractTask(abc.ABC):
11
    name = NotImplemented
12
    config = NotImplemented
13
    prefix = NotImplemented
14
    metric = NotImplemented
15
    metric_names = NotImplemented
16
    split_map = None
17
    labels_list = None
18
    split_to_data_split: Mapping[str, str] = \
19
        {"train": "train", "validation": "validation", "test": "test"}
20
    split_valid_to_make_test = True
21
    split_train_to_make_test = False
22
    keep_fields_after_preprocess = ["label"]  # The fields that should be kept even after preprocessiing
23

24
    def __init__(self, config, data_args, seed=42, default_max_length=1):
25
        self.config = config
26
        self.seed = seed
27
        self.data_args = data_args
28

29
        self.default_max_length = default_max_length
30

31
    def check_n_obs(self, n_obs, total_size):
32
        if n_obs is not None and n_obs > total_size:
33
            n_obs = total_size
34
            logger.warning("n_obs is set to %s", n_obs)
35
        return n_obs
36

37
    def shuffled_indices(self, dataset):
38
        num_samples = len(dataset)
39
        generator = torch.Generator()
40
        generator.manual_seed(self.seed)
41
        return torch.randperm(num_samples, generator=generator).tolist()
42

43
    def subsample(self, dataset, n_obs=None, indices=None):
44
        """
45
        Given a dataset returns the subsampled dataset.
46
        :param n_obs: the number of samples of the subsampled dataset.
47
        :param indices: indices to select the samples from, if not given, indices are computed
48
        from by shuffling the given dataset.
49
        :return: subsampled dataset.
50
        """
51
        num_samples = len(dataset)
52
        n_obs = self.check_n_obs(n_obs, num_samples)
53
        if indices is None:
54
           indices = self.shuffled_indices(dataset)
55
        indices = indices[:n_obs]
56
        return dataset.select(indices)
57

58
    def load_dataset(self, split: int):
59
        return datasets.load_dataset(self.name, self.config, split=split, script_version="master")
60

61
    def get_split_indices(self, split, dataset, validation_size):
62
        indices = self.shuffled_indices(dataset)
63
        if split == "validation":
64
            return indices[:validation_size]
65
        else:
66
            return indices[validation_size:]
67

68
    def preprocessor(self, example):
69
        return example
70

71
    def get(self, split, n_obs=None, split_validation_test=False):
72
        # For small datasets (n_samples < 10K) without test set, we divide validation set to
73
        # half, use one half as test set and one half as validation set.
74
        if split in ["eval", "dev", "valid"]:
75
            split = "validation"
76
        if split_validation_test and self.split_valid_to_make_test \
77
                and split != "train":
78
            mapped_split = self.split_to_data_split["validation"]
79
            dataset = self.load_dataset(split=mapped_split)
80
            indices = self.get_split_indices(split, dataset, validation_size=len(dataset)//2)
81
            dataset = self.subsample(dataset, n_obs, indices)
82
        # For larger datasets (n_samples > 10K), we divide training set into 1K as
83
        # validation and the rest as training set, keeping the original validation
84
        # set as the test set.
85
        elif split_validation_test and self.split_train_to_make_test \
86
                and split != "test":
87
            dataset = self.load_dataset(split="train")
88
            indices = self.get_split_indices(split, dataset, validation_size=1000)
89
            dataset = self.subsample(dataset, n_obs, indices)
90
        else:
91
            mapped_split = self.split_to_data_split[split]
92
            dataset = self.load_dataset(split=mapped_split)
93
            # shuffles the data and samples it.
94
            if n_obs is not None:
95
                dataset = self.subsample(dataset, n_obs)
96

97
        this_method = getattr(self.__class__, 'preprocessor')
98
        base_method = getattr(AbstractTask, 'preprocessor')
99
        if this_method is not base_method:
100
            return dataset.map(self.preprocessor)
101
        else:
102
            return dataset
103

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.