lmops

Форк
0
/
format_all_tasks.py 
77 строк · 2.7 Кб
1
import os
2
import sys
3
import json
4
import argparse
5

6
sys.path.insert(0, 'src/')
7

8
from typing import List
9
from datasets import Dataset, concatenate_datasets
10

11
from utils import save_dataset
12
from tasks import task_map, BaseTask
13
from logger_config import logger
14

15
parser = argparse.ArgumentParser(description='data preprocessing for all tasks')
16
parser.add_argument('--output-dir', default='./data/tasks/',
17
                    type=str, metavar='N', help='output directory')
18
parser.add_argument('--template-idx', default=0, type=int, metavar='N',
19
                    help='template index for the task')
20
parser.add_argument('--max-train-examples', default=30_000, type=int, metavar='N',
21
                    help='maximum number of training examples per task')
22

23
args = parser.parse_args()
24
os.makedirs(args.output_dir, exist_ok=True)
25
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))
26

27

28
def format_and_save_corpus():
29
    corpus_list: List[Dataset] = []
30
    for task_name, task_cls in task_map.cls_dic.items():
31
        task: BaseTask = task_cls(template_idx=args.template_idx)
32
        logger.info('Task: {}'.format(task_name))
33
        task_corpus: Dataset = task.get_corpus()
34
        if task_corpus is None:
35
            continue
36

37
        logger.info('Task: {}, corpus size: {}'.format(task_name, len(task_corpus)))
38
        corpus_list.append(task_corpus)
39

40
    corpus: Dataset = concatenate_datasets(corpus_list)
41
    corpus = corpus.add_column('id', [str(i) for i in range(len(corpus))])
42

43
    out_path: str = '{}/passages.jsonl.gz'.format(args.output_dir)
44
    save_dataset(corpus, out_path=out_path)
45
    logger.info('Save {} lines to {}'.format(len(corpus), out_path))
46

47

48
def prepare_split(split: str = 'test'):
49
    dataset_list: List[Dataset] = []
50
    for task_name, task_cls in task_map.cls_dic.items():
51
        task: BaseTask = task_cls(template_idx=args.template_idx)
52
        logger.info('Task: {}'.format(task_name))
53
        task_ds: Dataset = task.get_task_data(split=split)
54
        if task_ds is None:
55
            continue
56

57
        logger.info('Task: {}, size: {}'.format(task_name, len(task_ds)))
58
        if split == 'train' and len(task_ds) > args.max_train_examples:
59
            task_ds = task_ds.shuffle().select(range(args.max_train_examples))
60
            logger.info('Random sample to {} examples'.format(len(task_ds)))
61
        dataset_list.append(task_ds)
62

63
    dataset: Dataset = concatenate_datasets(dataset_list)
64

65
    out_path: str = os.path.join(args.output_dir, '{}.jsonl.gz'.format(split))
66
    save_dataset(dataset, out_path)
67
    logger.info('Save {} examples to {}'.format(len(dataset), out_path))
68

69

70
def main():
71
    format_and_save_corpus()
72
    for split in ['train', 'test']:
73
        prepare_split(split)
74

75

76
if __name__ == '__main__':
77
    main()
78

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.