FEDOT

Форк
0
/
datamall_client.py 
243 строки · 9,2 КБ
1
import json
2
import os
3
import shutil
4
import time
5
from datetime import datetime, timedelta
6
from typing import List, Optional
7

8
import requests
9

10
from fedot.core.pipelines.pipeline import Pipeline
11
from fedot.remote.infrastructure.clients.client import Client
12

13
# Cookie name under which the authorization server returns the user's JWT;
# it is sent back as a cookie on every models-controller request.
USER_TOKEN_KEY = 'x-jwt-auth'
# Header name carrying the execution-group token on models-controller requests.
GROUP_TOKEN_KEY = 'x-jwt-models-execution'

# Default execution settings (container paths, image and timeout) used when
# creating an execution on the models controller.
DEFAULT_EXEC_PARAMS = dict(
    container_input_path="/home/FEDOT/input_data_dir",
    container_output_path="/home/FEDOT/output_data_dir",
    container_config_path="/home/FEDOT/.config",
    container_image="fedot:dm-9",
    timeout=360,  # NOTE(review): units not visible here (presumably seconds) — confirm
)

# Example of connection parameters for DataMall, in the shape DataMallClient expects.
# NOTE(review): these are demo credentials and an internal address; real deployments
# should supply their own values instead of relying on this dict.
DEFAULT_CONNECT_PARAMS = dict(
    FEDOT_LOGIN='fedot',
    FEDOT_PASSWORD='fedot-password',
    AUTH_SERVER='http://10.32.0.51:30880/b',
    CONTR_SERVER='http://10.32.0.51:30880/models-controller',
    PROJECT_ID='83',
    DATA_ID='60',
)
33

34

35
# TO BE MOVED TO PYPI AS EXTERNAL LIB
36

37
class DataMallClient(Client):
    """Client that runs FEDOT tasks remotely through the DataMall models controller.

    Talks to two HTTP services: an authorization server (user login) and a models
    controller (execution groups, executions, result downloads). Every request
    helper raises ``ValueError`` with the server's response text when the reply
    has an unexpected status code.
    """

    # Execution statuses after which ``wait_until_ready`` stops polling.
    _FINAL_STATUSES = frozenset({'Succeeded', 'Failed', 'Timeout', 'Interrupted'})

    def __init__(self, connect_params: dict, exec_params: dict, output_path: Optional[str] = None):
        """Log in, create an execution group for the project and obtain its token.

        :param connect_params: connection settings. Reads ``FEDOT_LOGIN``, ``FEDOT_PASSWORD``
            and ``PROJECT_ID``; ``AUTH_SERVER`` / ``CONTR_SERVER`` fall back to the environment
            variables of the same name when missing or ``None``.
        :param exec_params: execution settings, forwarded to the base ``Client``.
        :param output_path: forwarded to the base ``Client``; results are unpacked under it.
        :raises KeyError: if a server URL is in neither ``connect_params`` nor the environment.
        :raises ValueError: if login or execution-group setup fails.
        """
        # .get() so that an absent key falls back to the environment just like an explicit None
        authorization_server = connect_params.get('AUTH_SERVER')
        controller_server = connect_params.get('CONTR_SERVER')
        self.authorization_server = os.environ['AUTH_SERVER'] if authorization_server is None else authorization_server
        self.controller_server = os.environ['CONTR_SERVER'] if controller_server is None else controller_server
        self._user_token = None
        self.user = None
        self._current_project_id = None
        self.group_token = None

        self._login(login=connect_params['FEDOT_LOGIN'],
                    password=connect_params['FEDOT_PASSWORD'])

        pid = connect_params['PROJECT_ID']
        group = self._create_execution_group(project_id=pid)
        self._set_group_token(project_id=pid, group_id=group['id'])

        super().__init__(connect_params, exec_params, output_path)

    @staticmethod
    def _ensure_status(response, action: str, expected_status: int = 200) -> None:
        """Raise ``ValueError('Unable to <action>. Reason: <body>')`` unless the status matches."""
        if response.status_code != expected_status:
            raise ValueError(f'Unable to {action}. Reason: {response.text}')

    def create_task(self, config) -> str:
        """Create a remote execution using ``self.exec_params`` and return its id.

        :param config: task configuration uploaded as the ``config_file`` form part.
        :return: the ``id`` field of the execution object returned by the controller.
        """
        # NOTE(review): connect_params['DATA_ID'] is not sent to the controller here;
        # confirm whether the input path is expected to reference it.
        params = self.exec_params
        created_ex = self._create_execution(str(params['container_input_path']),
                                            params['container_output_path'],
                                            params['container_config_path'],
                                            params['container_image'],
                                            params['timeout'],
                                            config=config)
        return created_ex['id']

    def wait_until_ready(self, poll_interval: float = 5) -> timedelta:
        """Block until every execution of the current project reaches a final status.

        Polls the controller, logging ``id=status`` for each execution, and sleeps
        ``poll_interval`` seconds between polls. There is no client-side deadline: the loop
        ends only when every execution is Succeeded, Failed, Timeout or Interrupted.

        :param poll_interval: seconds to sleep between polls.
        :return: wall-clock time spent waiting.
        """
        executions = self._get_executions()
        self._logger.info(executions)
        start = datetime.now()
        while True:
            self._logger.info([f"{execution['id']}={execution['status']};" for execution in executions])
            if all(execution['status'] in self._FINAL_STATUSES for execution in executions):
                break
            # Sleep only while something is still running, so finished runs return at once.
            time.sleep(poll_interval)
            executions = self._get_executions()
        return datetime.now() - start

    def _login(self, login: str, password: str) -> None:
        """Authenticate against the authorization server.

        Stores the ``USER_TOKEN_KEY`` cookie in ``self._user_token`` and the decoded JSON
        body in ``self.user``.

        :raises ValueError: if the server does not answer with HTTP 200.
        """
        response = requests.post(
            url=f'{self.authorization_server}/users/login',
            json={'login': login, 'password': password},
        )
        self._ensure_status(response, 'get user token')

        # Same cookie name that every controller request sends back.
        self._user_token = response.cookies[USER_TOKEN_KEY]
        self.user = json.loads(response.text)

    def _set_group_token(self, project_id: int, group_token: Optional[str] = None,
                         group_id: Optional[int] = None) -> None:
        """Select the execution group used for subsequent execution requests.

        Either takes ``group_token`` as given, or fetches the group ``group_id`` and uses its
        ``token`` and ``project_id``.

        :raises ValueError: if neither ``group_token`` nor ``group_id`` is given.
        """
        if group_token is not None:
            self.group_token = group_token
            self._current_project_id = project_id
            return

        if group_id is not None:
            group = self._get_execution_group(project_id=project_id, group_id=group_id)
            self.group_token = group['token']
            self._current_project_id = group['project_id']
            return

        raise ValueError('You have to specify project_id and token/group_id!')

    def _get_execution_groups(self, project_id: int) -> List[dict]:
        """Return the decoded list of execution groups of ``project_id``."""
        response = requests.get(
            url=f'{self.controller_server}/execution-groups/{project_id}',
            cookies={USER_TOKEN_KEY: self._user_token}
        )
        self._ensure_status(response, 'get execution groups')
        return json.loads(response.text)

    def _get_execution_group(self, project_id: int, group_id: int) -> dict:
        """Return the decoded execution group ``group_id`` of ``project_id``."""
        response = requests.get(
            url=f'{self.controller_server}/execution-groups/{project_id}/{group_id}',
            cookies={USER_TOKEN_KEY: self._user_token}
        )
        self._ensure_status(response, 'get execution group')
        return json.loads(response.text)

    def _create_execution_group(self, project_id: int) -> dict:
        """Create a new execution group in ``project_id`` and return its decoded description."""
        response = requests.post(
            url=f'{self.controller_server}/execution-groups/{project_id}',
            cookies={USER_TOKEN_KEY: self._user_token}
        )
        self._ensure_status(response, 'create execution group')
        return json.loads(response.text)

    def _get_executions(self) -> List[dict]:
        """Return the decoded executions of the current project visible to the group token."""
        response = requests.get(
            url=f'{self.controller_server}/executions/{self._current_project_id}',
            cookies={USER_TOKEN_KEY: self._user_token},
            headers={GROUP_TOKEN_KEY: self.group_token}
        )
        self._ensure_status(response, 'get executions')
        return json.loads(response.text)

    def _get_execution(self, execution_id: int) -> dict:
        """Return the decoded description of one execution of the current project."""
        response = requests.get(
            url=f'{self.controller_server}/executions/{self._current_project_id}/{execution_id}',
            cookies={USER_TOKEN_KEY: self._user_token},
            headers={GROUP_TOKEN_KEY: self.group_token}
        )
        self._ensure_status(response, 'get execution')
        return json.loads(response.text)

    def _create_execution(self, container_input_path: str,
                          container_output_path: str,
                          container_config_path: str,
                          container_image: str,
                          timeout: int,
                          config: bytes) -> dict:
        """Submit a new execution as a multipart form and return its decoded description.

        All settings are sent as plain form fields; ``config`` is uploaded as a file part
        named ``config``.
        """
        response = requests.post(
            url=f'{self.controller_server}/executions/{self._current_project_id}',
            cookies={USER_TOKEN_KEY: self._user_token},
            headers={GROUP_TOKEN_KEY: self.group_token},
            files={
                'input_path': (None, container_input_path),
                'output_path': (None, container_output_path),
                'config_path': (None, container_config_path),
                'image': (None, container_image),
                # explicit str: multipart form fields are text
                'timeout': (None, str(timeout)),
                'config_file': ('config', config)
            }
        )
        self._ensure_status(response, 'create execution')
        return json.loads(response.text)

    def _stop_execution(self, execution_id: int) -> None:
        """Ask the controller to stop an execution; the controller replies 204 on success."""
        response = requests.delete(
            url=f'{self.controller_server}/executions/{self._current_project_id}/{execution_id}',
            cookies={USER_TOKEN_KEY: self._user_token},
            headers={GROUP_TOKEN_KEY: self.group_token}
        )
        self._ensure_status(response, 'stop execution', expected_status=204)

    def download_result(self, execution_id: int, result_cls=Pipeline) -> Pipeline:
        """Download, unpack and load the result of an execution.

        The zip archive is unpacked into ``<output_path>/execution-<execution_id>``.
        ``fitted_pipeline.json`` is loaded with ``result_cls.from_serialized`` from the first
        (alphabetical) sub-folder of its ``out`` directory. The files under ``out`` are then
        deleted with ``clean_dir``.

        :param execution_id: id of the execution whose results are downloaded.
        :param result_cls: class providing ``from_serialized(path)``; ``Pipeline`` by default.
        :return: the loaded result object.
        :raises ValueError: if the download fails or ``out`` contains no result folder.
        """
        import tempfile  # only needed here

        execution_dir = os.path.join(self.output_path, f'execution-{execution_id}')

        # Use the response as a context manager so the streamed connection is always released.
        with requests.get(
            url=f'{self.controller_server}/executions/{self._current_project_id}/{execution_id}/download',
            cookies={USER_TOKEN_KEY: self._user_token},
            headers={GROUP_TOKEN_KEY: self.group_token},
            stream=True
        ) as response:
            self._ensure_status(response, 'download results')
            # response.raw skips requests' Content-Encoding handling; let urllib3 decode it.
            response.raw.decode_content = True

            # Private temp file instead of a timestamp-named file in the CWD.
            fd, tmp_path = tempfile.mkstemp(suffix='.zip')
            try:
                with os.fdopen(fd, 'wb') as tmp_file:
                    shutil.copyfileobj(response.raw, tmp_file)
                shutil.unpack_archive(tmp_path, execution_dir, 'zip')
            finally:
                try:
                    os.remove(tmp_path)
                except FileNotFoundError:
                    pass

        results_path_out = os.path.join(execution_dir, 'out')
        # Sorted so the choice is deterministic; os.listdir order is arbitrary.
        result_folders = sorted(os.listdir(results_path_out))
        if not result_folders:
            raise ValueError(f'Unable to load results. Reason: {results_path_out} is empty')
        load_path = os.path.join(results_path_out, result_folders[0], 'fitted_pipeline.json')
        pipeline = result_cls.from_serialized(load_path)

        clean_dir(results_path_out)
        return pipeline
238

239

240
def clean_dir(results_path_out):
    """Delete every file found under ``results_path_out``, recursively.

    Only files are removed; the directory tree itself is left in place. A path
    that does not exist is a no-op, because ``os.walk`` yields nothing for it.
    """
    for dir_path, _sub_dirs, file_names in os.walk(results_path_out):
        doomed = [os.path.join(dir_path, name) for name in file_names]
        for path in doomed:
            os.remove(path)
244

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.