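"""Data preprocessing for all tasks.

Iterates over every task registered in ``task_map``, builds a shared
retrieval corpus plus train/test splits, and writes them to ``--output-dir``
as gzip-compressed JSONL files (``passages.jsonl.gz``, ``train.jsonl.gz``,
``test.jsonl.gz``).
"""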
import os
import sys
import json
import argparse

# Make the project's src/ modules (utils, tasks, logger_config) importable.
sys.path.insert(0, 'src/')

from typing import List
from datasets import Dataset, concatenate_datasets

from utils import save_dataset
from tasks import task_map, BaseTask
from logger_config import logger

parser = argparse.ArgumentParser(description='data preprocessing for all tasks')
parser.add_argument('--output-dir', default='./data/tasks/',
                    type=str, metavar='N', help='output directory')
parser.add_argument('--template-idx', default=0, type=int, metavar='N',
                    help='template index for the task')
parser.add_argument('--max-train-examples', default=30_000, type=int, metavar='N',
                    help='maximum number of training examples per task')

args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4)))


def format_and_save_corpus():
    """Concatenate every task's corpus and save it as passages.jsonl.gz."""
    corpus_list: List[Dataset] = []
    for task_name, task_cls in task_map.cls_dic.items():
        task: BaseTask = task_cls(template_idx=args.template_idx)
        logger.info('Task: {}'.format(task_name))
        task_corpus: Dataset = task.get_corpus()
        # Tasks without a retrieval corpus are skipped.
        if task_corpus is None:
            continue

        logger.info('Task: {}, corpus size: {}'.format(task_name, len(task_corpus)))
        corpus_list.append(task_corpus)

    corpus: Dataset = concatenate_datasets(corpus_list)
    # Assign sequential string ids so each passage can be referenced later.
    corpus = corpus.add_column('id', [str(i) for i in range(len(corpus))])

    out_path: str = os.path.join(args.output_dir, 'passages.jsonl.gz')
    save_dataset(corpus, out_path=out_path)
    logger.info('Save {} lines to {}'.format(len(corpus), out_path))


def prepare_split(split: str = 'test'):
    """Gather one split across all tasks and save it as <split>.jsonl.gz."""
    dataset_list: List[Dataset] = []
    for task_name, task_cls in task_map.cls_dic.items():
        task: BaseTask = task_cls(template_idx=args.template_idx)
        logger.info('Task: {}'.format(task_name))
        task_ds: Dataset = task.get_task_data(split=split)
        # Tasks without this split are skipped.
        if task_ds is None:
            continue

        logger.info('Task: {}, size: {}'.format(task_name, len(task_ds)))
        # Cap oversized training sets by random down-sampling.
        if split == 'train' and len(task_ds) > args.max_train_examples:
            task_ds = task_ds.shuffle().select(range(args.max_train_examples))
            logger.info('Random sample to {} examples'.format(len(task_ds)))
        dataset_list.append(task_ds)

    dataset: Dataset = concatenate_datasets(dataset_list)

    out_path: str = os.path.join(args.output_dir, '{}.jsonl.gz'.format(split))
    save_dataset(dataset, out_path)
    logger.info('Save {} examples to {}'.format(len(dataset), out_path))


def main():
    format_and_save_corpus()
    for split in ['train', 'test']:
        prepare_split(split)


if __name__ == '__main__':
    main()
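# A minimal usage sketch (hypothetical, not part of the original script):
# the saved files can be re-loaded with the same `datasets` library, e.g.
#
#   from datasets import load_dataset
#   corpus = load_dataset('json', data_files='./data/tasks/passages.jsonl.gz',
#                         split='train')
#   print(len(corpus), corpus[0]['id'])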