import json
import os
import sys
from typing import Union

from golem.core.log import default_log
from golem.utilities.random import RandomStateHandler

from fedot.core.data.data import InputData
from fedot.core.data.multi_modal import MultiModalData
from fedot.core.pipelines.pipeline import Pipeline
from fedot.core.pipelines.verification import verifier_for_task
from fedot.core.repository.dataset_types import DataTypesEnum
from fedot.core.repository.tasks import TaskTypesEnum
from fedot.remote.pipeline_run_config import PipelineRunConfig
def new_key_name(data_part_key):
    """Map a multimodal data-part key to its data-source node name.

    The index column keeps its own name; every other data part is
    prefixed with ``data_source_ts/``.
    """
    if data_part_key == 'idx':
        # the 'idx' branch had no body in the mangled source; without a
        # return here the prefixed name below would also apply to 'idx'
        return 'idx'
    return f'data_source_ts/{data_part_key}'
def _load_ts_data(config):
    """Load time-series train data described by *config*.

    Returns ``MultiModalData`` when ``config.is_multi_modal`` is set,
    otherwise plain ``InputData`` — both built from the CSV at
    ``config.input_data``.
    """
    # bug fix: the original referenced an undefined name `task`
    task = config.task
    if config.is_multi_modal:
        train_data = MultiModalData.from_csv_time_series(
            file_path=config.input_data,
            task=task, target_column=config.target,
            columns_to_use=config.var_names)
    else:
        # bug fix: without this `else` the single-modal load clobbered
        # the multimodal data loaded above
        train_data = InputData.from_csv_time_series(
            file_path=config.input_data,
            task=task, target_column=config.target)
    return train_data
def _load_data(config):
    """Load train data for the task described by *config*.

    Dispatches to the time-series loader for forecasting tasks and
    falls back to a plain table CSV load otherwise.
    """
    data_type = DataTypesEnum.table
    if config.task.task_type == TaskTypesEnum.ts_forecasting:
        train_data = _load_ts_data(config)
    else:
        # bug fix: without this `else` the table load clobbered the
        # ts-forecasting data loaded above
        train_data = InputData.from_csv(file_path=config.input_data,
                                        task=config.task, data_type=data_type)
    return train_data
def fit_pipeline(config_file: Union[str, bytes], save_pipeline: bool = True) -> bool:
    """Fit the pipeline described by a remote-run config file.

    :param config_file: path to (or content of) a ``PipelineRunConfig`` file.
    :param save_pipeline: if True, save the fitted pipeline under
        ``config.output_path``.
    :return: True on success, False when verification or fitting fails.
    """
    logger = default_log(prefix='pipeline_fitting_logger')

    # bug fix: the loaded config was discarded in the original, yet
    # `config` was used everywhere below
    config = PipelineRunConfig().load_from_file(config_file)

    verifier = verifier_for_task(config.task.task_type)

    pipeline = pipeline_from_json(config.pipeline_template)

    train_data = _load_data(config)

    # subset data using indices
    if config.train_data_idx not in [None, []]:
        train_data = train_data.subset_indices(config.train_data_idx)

    if not verifier(pipeline):
        logger.error('Pipeline not valid')
        # bug fix: signal failure instead of falling through to fitting
        return False

    try:
        # pin the seed so remote fitting is reproducible
        RandomStateHandler.MODEL_FITTING_SEED = 0
        pipeline.fit_from_scratch(train_data)
    except Exception as ex:
        # broad catch is deliberate: any fitting error means a failed run
        logger.error(ex)
        return False

    if config.test_data_path:
        test_data = InputData.from_csv(config.test_data_path)
        pipeline.predict(test_data)

    # bug fix: honor the `save_pipeline` flag, which was unused before
    if save_pipeline:
        pipeline.save(path=os.path.join(config.output_path, 'fitted_pipeline'), create_subdir=False,
                      is_datetime_in_path=False)
    return True
def pipeline_from_json(json_str: str):
    """Deserialize a ``Pipeline`` from its JSON template string.

    :param json_str: JSON-encoded pipeline template.
    :return: the reconstructed ``Pipeline`` instance.
    """
    json_dict = json.loads(json_str)
    pipeline = Pipeline.from_serialized(json_dict)
    # bug fix: the original built the pipeline but never returned it,
    # so callers (fit_pipeline) received None
    return pipeline
if __name__ == '__main__':
    # the path to the run-config file is passed as the only CLI argument
    fit_pipeline(sys.argv[1])