OpenDelta
102 lines · 4.1 KB
1import abc
2from typing import Callable, List, Mapping, Dict
3import datasets
4import logging
5import numpy as np
6import torch
7logger = logging.getLogger(__name__)
8
9
class AbstractTask(abc.ABC):
    """Base class for a dataset task: loading, deterministic shuffling,
    subsampling, and train/validation/test split handling.

    Subclasses are expected to set the class attributes below (``name``,
    ``config``, ``metric``, ...) and optionally override
    :meth:`preprocessor` to map raw examples into the task's format.
    """

    name = NotImplemented
    config = NotImplemented
    prefix = NotImplemented
    metric = NotImplemented
    metric_names = NotImplemented
    split_map = None
    labels_list = None
    # Maps the user-facing split name to the split name used by the
    # underlying dataset loading script.
    split_to_data_split: Mapping[str, str] = \
        {"train": "train", "validation": "validation", "test": "test"}
    # If True (and the caller asks for it), the validation set is halved
    # into a validation half and a test half (small datasets, no test set).
    split_valid_to_make_test = True
    # If True, 1K examples are carved out of train as validation and the
    # original validation set serves as the test set (large datasets).
    split_train_to_make_test = False
    keep_fields_after_preprocess = ["label"]  # The fields that should be kept even after preprocessing

    def __init__(self, config, data_args, seed=42, default_max_length=1):
        self.config = config
        self.seed = seed
        self.data_args = data_args
        self.default_max_length = default_max_length

    def check_n_obs(self, n_obs, total_size):
        """Clamp ``n_obs`` to ``total_size``; ``None`` passes through unchanged.

        Logs a warning when the requested sample count exceeds the dataset size.
        """
        if n_obs is not None and n_obs > total_size:
            n_obs = total_size
            logger.warning("n_obs is set to %s", n_obs)
        return n_obs

    def shuffled_indices(self, dataset):
        """Return a seeded, deterministic permutation of ``range(len(dataset))``.

        A fresh ``torch.Generator`` seeded with ``self.seed`` is used so the
        permutation is reproducible across calls and processes.
        """
        num_samples = len(dataset)
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        return torch.randperm(num_samples, generator=generator).tolist()

    def subsample(self, dataset, n_obs=None, indices=None):
        """
        Given a dataset returns the subsampled dataset.
        :param n_obs: the number of samples of the subsampled dataset.
        :param indices: indices to select the samples from, if not given, indices are computed
        from by shuffling the given dataset.
        :return: subsampled dataset.
        """
        num_samples = len(dataset)
        n_obs = self.check_n_obs(n_obs, num_samples)
        if indices is None:
            indices = self.shuffled_indices(dataset)
        indices = indices[:n_obs]
        return dataset.select(indices)

    def load_dataset(self, split: str):
        """Load this task's dataset for the given split name (e.g. ``"train"``).

        NOTE(review): ``script_version`` was removed in newer ``datasets``
        releases in favor of ``revision`` — confirm against the pinned version.
        """
        return datasets.load_dataset(self.name, self.config, split=split, script_version="master")

    def get_split_indices(self, split, dataset, validation_size):
        """Split one shuffled index permutation into a validation prefix
        (first ``validation_size`` indices) and the remainder for the other split."""
        indices = self.shuffled_indices(dataset)
        if split == "validation":
            return indices[:validation_size]
        else:
            return indices[validation_size:]

    def preprocessor(self, example):
        """Identity by default; subclasses override to transform an example."""
        return example

    def get(self, split, n_obs=None, split_validation_test=False):
        """Return the (optionally subsampled and preprocessed) dataset split.

        :param split: one of "train"/"validation"/"test" (aliases "eval",
            "dev", "valid" map to "validation").
        :param n_obs: optional cap on the number of examples.
        :param split_validation_test: when True, re-derive validation/test
            splits according to ``split_valid_to_make_test`` /
            ``split_train_to_make_test``.
        """
        # For small datasets (n_samples < 10K) without test set, we divide validation set to
        # half, use one half as test set and one half as validation set.
        if split in ["eval", "dev", "valid"]:
            split = "validation"
        if split_validation_test and self.split_valid_to_make_test \
                and split != "train":
            mapped_split = self.split_to_data_split["validation"]
            dataset = self.load_dataset(split=mapped_split)
            indices = self.get_split_indices(split, dataset, validation_size=len(dataset) // 2)
            dataset = self.subsample(dataset, n_obs, indices)
        # For larger datasets (n_samples > 10K), we divide training set into 1K as
        # validation and the rest as training set, keeping the original validation
        # set as the test set.
        elif split_validation_test and self.split_train_to_make_test \
                and split != "test":
            dataset = self.load_dataset(split="train")
            indices = self.get_split_indices(split, dataset, validation_size=1000)
            dataset = self.subsample(dataset, n_obs, indices)
        else:
            mapped_split = self.split_to_data_split[split]
            dataset = self.load_dataset(split=mapped_split)
            # shuffles the data and samples it.
            if n_obs is not None:
                dataset = self.subsample(dataset, n_obs)

        # Only run .map() when a subclass actually overrides preprocessor;
        # the base-class identity mapping would be a wasted pass.
        this_method = getattr(self.__class__, 'preprocessor')
        base_method = getattr(AbstractTask, 'preprocessor')
        if this_method is not base_method:
            return dataset.map(self.preprocessor)
        else:
            return dataset
103