FEDOT

Форк
0
/
run_pipeline.py 
95 строк · 2.8 Кб
1
import json
2
import os
3
import sys
4
from typing import Union
5

6
from golem.core.log import default_log
7
from golem.utilities.random import RandomStateHandler
8

9
from fedot.core.data.data import InputData
10
from fedot.core.data.multi_modal import MultiModalData
11
from fedot.core.pipelines.pipeline import Pipeline
12
from fedot.core.pipelines.verification import verifier_for_task
13
from fedot.core.repository.dataset_types import DataTypesEnum
14
from fedot.core.repository.tasks import TaskTypesEnum
15
from fedot.remote.pipeline_run_config import PipelineRunConfig
16

17

18
def new_key_name(data_part_key):
19
    if data_part_key == 'idx':
20
        return 'idx'
21
    return f'data_source_ts/{data_part_key}'
22

23

24
def _load_ts_data(config):
25
    task = config.task
26
    if config.is_multi_modal:
27
        train_data = MultiModalData.from_csv_time_series(
28
            file_path=config.input_data,
29
            task=task, target_column=config.target,
30
            columns_to_use=config.var_names)
31
    else:
32
        train_data = InputData.from_csv_time_series(
33
            file_path=config.input_data,
34
            task=task, target_column=config.target)
35
    return train_data
36

37

38
def _load_data(config):
39
    data_type = DataTypesEnum.table
40
    if config.task.task_type == TaskTypesEnum.ts_forecasting:
41
        train_data = _load_ts_data(config)
42
    else:
43
        train_data = InputData.from_csv(file_path=config.input_data,
44
                                        task=config.task, data_type=data_type)
45
    return train_data
46

47

48
def fit_pipeline(config_file: Union[str, bytes], save_pipeline: bool = True) -> bool:
49
    logger = default_log(prefix='pipeline_fitting_logger')
50

51
    config = \
52
        PipelineRunConfig().load_from_file(config_file)
53

54
    verifier = verifier_for_task(config.task.task_type)
55

56
    pipeline = pipeline_from_json(config.pipeline_template)
57

58
    train_data = _load_data(config)
59

60
    # subset data using indices
61
    if config.train_data_idx not in [None, []]:
62
        train_data = train_data.subset_indices(config.train_data_idx)
63

64
    if not verifier(pipeline):
65
        logger.error('Pipeline not valid')
66
        return False
67

68
    try:
69
        RandomStateHandler.MODEL_FITTING_SEED = 0
70
        pipeline.fit_from_scratch(train_data)
71
    except Exception as ex:
72
        logger.error(ex)
73
        return False
74

75
    if config.test_data_path:
76
        test_data = InputData.from_csv(config.test_data_path)
77
        pipeline.predict(test_data)
78

79
    if save_pipeline:
80
        pipeline.save(path=os.path.join(config.output_path, 'fitted_pipeline'), create_subdir=False,
81
                      is_datetime_in_path=False)
82

83
    return True
84

85

86
def pipeline_from_json(json_str: str):
87
    json_dict = json.loads(json_str)
88
    pipeline = Pipeline.from_serialized(json_dict)
89

90
    return pipeline
91

92

93
if __name__ == '__main__':
94
    config_file_name_from_argv = sys.argv[1]
95
    fit_pipeline(config_file_name_from_argv)
96

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.