pytorch-lightning

Форк
0
737 строк · 29.2 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import logging
16
import os
17
import pickle
18
import queue
19
import threading
20
import warnings
21
from copy import deepcopy
22
from time import time
23
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
24

25
from deepdiff import DeepDiff, Delta
26
from lightning_utilities.core.apply_func import apply_to_collection
27

28
import lightning.app
29
from lightning.app import _console
30
from lightning.app.api.request_types import _APIRequest, _CommandRequest, _DeltaRequest
31
from lightning.app.core.constants import (
32
    BATCH_DELTA_COUNT,
33
    DEBUG_ENABLED,
34
    FLOW_DURATION_SAMPLES,
35
    FLOW_DURATION_THRESHOLD,
36
    FRONTEND_DIR,
37
    STATE_ACCUMULATE_WAIT,
38
)
39
from lightning.app.core.queues import BaseQueue
40
from lightning.app.core.work import LightningWork
41
from lightning.app.frontend import Frontend
42
from lightning.app.storage import Drive, Path, Payload
43
from lightning.app.storage.path import _storage_root_dir
44
from lightning.app.utilities import frontend
45
from lightning.app.utilities.app_helpers import (
46
    Logger,
47
    _delta_to_app_state_delta,
48
    _LightningAppRef,
49
    _should_dispatch_app,
50
)
51
from lightning.app.utilities.app_status import AppStatus
52
from lightning.app.utilities.commands.base import _process_requests
53
from lightning.app.utilities.component import _convert_paths_after_init, _validate_root_flow
54
from lightning.app.utilities.enum import AppStage, CacheCallsKeys
55
from lightning.app.utilities.exceptions import CacheMissException, ExitAppException, LightningFlowException
56
from lightning.app.utilities.layout import _collect_layout
57
from lightning.app.utilities.proxies import ComponentDelta
58
from lightning.app.utilities.scheduler import SchedulerThread
59
from lightning.app.utilities.tree import breadth_first
60
from lightning.app.utilities.warnings import LightningFlowWarning
61

62
if TYPE_CHECKING:
63
    from lightning.app.core.flow import LightningFlow
64
    from lightning.app.runners.backends.backend import Backend, WorkManager
65
    from lightning.app.runners.runtime import Runtime
66
    from lightning.app.utilities.packaging.cloud_compute import CloudCompute
67

68

69
logger = Logger(__name__)
70

71

72
class LightningApp:
    def __init__(
        self,
        root: Union["LightningFlow", LightningWork],
        flow_cloud_compute: Optional["CloudCompute"] = None,
        log_level: str = "info",
        info: Optional[frontend.AppInfo] = None,
        root_path: str = "",
    ) -> None:
        """The Lightning App, or App in short runs a tree of one or more components that interact to create end-to-end
        applications. There are two kinds of components: :class:`~lightning.app.core.flow.LightningFlow` and
        :class:`~lightning.app.core.work.LightningWork`. This modular design enables you to reuse components created by
        other users.

        The Lightning App alternatively run an event loop triggered by delta changes sent from
        either :class:`~lightning.app.core.work.LightningWork` or from the Lightning UI.
        Once deltas are received, the Lightning App runs
        the :class:`~lightning.app.core.flow.LightningFlow` provided.

        Arguments:
            root: The root ``LightningFlow`` or ``LightningWork`` component, that defines all the app's nested
                 components, running infinitely. It must define a `run()` method that the app can call.
            flow_cloud_compute: The default Cloud Compute used for flow, Rest API and frontend's.
            log_level: The log level for the app, one of [`info`, `debug`].
                This can be helpful when reporting bugs on Lightning repo.
            info: Provide additional info about the app which will be used to update html title,
                description and image meta tags and specify any additional tags as list of html strings.
            root_path: Set this to `/path` if you want to run your app behind a proxy at `/path` leave empty for "/".
                For instance, if you want to run your app at `https://customdomain.com/myapp`,
                set `root_path` to `/myapp`.
                You can learn more about proxy `here <https://www.fortinet.com/resources/cyberglossary/proxy-server>`_.

        """

        self.root_path = root_path  # when running behind a proxy
        self.info = info

        # Imported here to avoid a circular import with the flow module.
        from lightning.app.core.flow import _RootFlow

        # A bare LightningWork is wrapped in a minimal flow so the app always runs a flow tree.
        if isinstance(root, LightningWork):
            root = _RootFlow(root)

        _validate_root_flow(root)
        self._root = root
        self.flow_cloud_compute = flow_cloud_compute or lightning.app.CloudCompute(name="flow-lite")

        # queues definition.
        # All queues are created later by the runtime/backend; they stay None until then.
        self.delta_queue: Optional[BaseQueue] = None
        self.readiness_queue: Optional[BaseQueue] = None
        self.api_response_queue: Optional[BaseQueue] = None
        self.api_publish_state_queue: Optional[BaseQueue] = None
        self.api_delta_queue: Optional[BaseQueue] = None
        self.error_queue: Optional[BaseQueue] = None
        self.request_queues: Optional[Dict[str, BaseQueue]] = None
        self.response_queues: Optional[Dict[str, BaseQueue]] = None
        self.copy_request_queues: Optional[Dict[str, BaseQueue]] = None
        self.copy_response_queues: Optional[Dict[str, BaseQueue]] = None
        self.caller_queues: Optional[Dict[str, BaseQueue]] = None
        self.flow_to_work_delta_queues: Optional[Dict[str, BaseQueue]] = None
        self.work_queues: Optional[Dict[str, BaseQueue]] = None
        self.commands: Optional[List] = None

        self.should_publish_changes_to_api = False
        self.component_affiliation = None
        self.backend: Optional["Backend"] = None
        # Register this instance as the global app reference.
        _LightningAppRef.connect(self)
        self.processes: Dict[str, "WorkManager"] = {}
        self.frontends: Dict[str, Frontend] = {}
        self.stage = AppStage.RUNNING
        # Start as True so the flow executes at least once (see run_once).
        self._has_updated: bool = True
        self._schedules: Dict[str, Dict] = {}
        self.threads: List[threading.Thread] = []
        self.exception = None
        self.collect_changes: bool = True

        self.status: Optional[AppStatus] = None
        # TODO: Enable ready locally for opening the UI.
        self.ready = False

        # NOTE: Checkpointing is disabled by default for the time being.  We
        # will enable it when resuming from full checkpoint is supported. Also,
        # we will need to revisit the logic at _should_snapshot, since right now
        # we are writing checkpoints too often, and this is expensive.
        self.checkpointing: bool = False

        self._update_layout()
        self._update_status()

        self.is_headless: Optional[bool] = None

        self._original_state: Optional[dict] = None
        self._last_state: dict = self.state
        self.state_accumulate_wait = STATE_ACCUMULATE_WAIT

        self._last_run_time: float = 0.0
        self._run_times: list = []

        # Path attributes can't get properly attached during the initialization, because the full name
        # is only available after all Flows and Works have been instantiated.
        _convert_paths_after_init(self.root)  # type: ignore[arg-type]

        if log_level not in ("debug", "info"):
            raise Exception(f"Log Level should be in ['debug', 'info']. Found {log_level}")

        # Lazily enable debugging.
        if log_level == "debug" or DEBUG_ENABLED:
            if not DEBUG_ENABLED:
                os.environ["LIGHTNING_DEBUG"] = "2"
            _console.setLevel(logging.DEBUG)

        logger.debug(f"ENV: {os.environ}")

        # Dispatch immediately (in-process) when running in an environment that requires it.
        if _should_dispatch_app():
            os.environ["LIGHTNING_DISPATCHED"] = "1"
            from lightning.app.runners import MultiProcessRuntime

            MultiProcessRuntime(self).dispatch()
189

190
    def _update_index_file(self) -> None:
        """Regenerate the frontend ``index.html`` with this app's meta info.

        This should happen once, for all apps, before the UI server starts running.
        """
        frontend.update_index_file(FRONTEND_DIR, root_path=self.root_path, info=self.info)
194

195
    def get_component_by_name(self, component_name: str) -> Union["LightningFlow", LightningWork]:
        """Resolve a dotted component name (e.g. ``root.child.work``) to its component instance.

        Raises:
            ValueError: If the name doesn't start with ``root``.
            AttributeError: If any segment of the path can't be resolved to a component.
        """
        from lightning.app.structures import Dict as LightningDict
        from lightning.app.structures import List as LightningList
        from lightning.app.utilities.types import ComponentTuple

        if component_name == "root":
            return self.root
        if not component_name.startswith("root."):
            raise ValueError(f"Invalid component name {component_name}. Name must start with 'root'")

        node = self.root
        for segment in component_name.split(".")[1:]:
            # Structures are indexed; plain flows expose children as attributes.
            if isinstance(node, LightningDict):
                candidate = node[segment]
            elif isinstance(node, LightningList):
                candidate = node[int(segment)]
            else:
                candidate = getattr(node, segment, None)
            if not isinstance(candidate, ComponentTuple):
                raise AttributeError(f"Component '{node.name}' has no child component with name '{segment}'.")
            node = candidate  # type: ignore[assignment]
        return node
218

219
    def _reset_original_state(self) -> None:
        """Restore the app to the state captured at the beginning of ``_run``."""
        original = self._original_state
        assert original is not None
        self.set_state(original)
222

223
    @property
    def root(self) -> Union["LightningFlow", LightningWork]:
        """The root component of the application."""
        return self._root

    @property
    def state(self) -> dict:
        """The full current state of the application, including the app stage."""
        full_state = self.root.state
        full_state["app_state"] = {"stage": self.stage.value}
        return full_state

    @property
    def state_vars(self) -> dict:
        """The current state restricted to user-defined variables, plus the app stage."""
        vars_state = self.root.state_vars
        vars_state["app_state"] = {"stage": self.stage.value}
        return vars_state

    @property
    def state_with_changes(self) -> dict:
        """The current state including the pending changes, plus the app stage."""
        changed_state = self.root.state_with_changes
        changed_state["app_state"] = {"stage": self.stage.value}
        return changed_state
248

249
    def set_state(self, state: dict) -> None:
        """Replace the application state and synchronize the app stage with it."""
        self.set_last_state(state)
        self.root.set_state(state)
        new_stage = state["app_state"]["stage"]
        self.stage = AppStage(new_stage)

    @property
    def last_state(self) -> dict:
        """The most recently recorded state (with transient ``changes`` stripped)."""
        return self._last_state

    @property
    def checkpoint_dir(self) -> str:
        """Directory under the storage root where app checkpoints are written."""
        return os.path.join(str(_storage_root_dir()), "checkpoints")
263

264
    def remove_changes_(self, state: dict) -> None:
        """Recursively clear every ``changes`` entry of ``state`` in place.

        Arguments:
            state: A (sub-)state dictionary; mutated directly.
        """
        for child in state["flows"].values():
            # Bug fix: this previously called ``remove_changes``, which deep-copies its
            # input and returns a new dict — the cleaned copy was discarded, so the
            # ``changes`` of grandchildren (depth >= 2 below ``state``) were never
            # cleared in place. Recurse with the in-place variant instead.
            self.remove_changes_(child)
        state["changes"] = {}
268

269
    def remove_changes(self, state: dict) -> dict:
        """Return a deep copy of ``state`` with every ``changes`` entry cleared."""
        snapshot = deepcopy(state)
        for child in snapshot["flows"].values():
            self.remove_changes_(child)
        snapshot["changes"] = {}
        return snapshot

    def set_last_state(self, state: dict) -> None:
        """Record ``state`` (minus transient ``changes``) as the last seen state."""
        cleaned = self.remove_changes(state)
        self._last_state = cleaned
278

279
    @staticmethod
    def populate_changes(last_state: dict, new_state: dict) -> dict:
        """Diff two app states and record per-component ``vars`` changes.

        Returns ``new_state`` (mutated in place): each component dict whose ``vars``
        changed gains a ``changes`` mapping of ``{key: {"from": old, "to": new}}``.
        """
        diff = DeepDiff(last_state, new_state, view="tree", verbose_level=2)

        # One list of tree entries per deepdiff change category (added/removed/changed...).
        changes_categories = [diff[key] for key in diff.to_dict()]

        if not changes_categories:
            return new_state

        for change_category in changes_categories:
            for entry in change_category:
                state_el = new_state
                # Path of dictionary keys from the state root to the changed value.
                change = entry.path(output_format="list")
                if "vars" not in change:
                    continue
                for change_el in change:
                    if change_el == "vars":
                        if "changes" not in state_el:
                            state_el["changes"] = {}
                        # entry.t1/t2 are deepdiff's before/after values for this entry.
                        state_el["changes"][change[-1]] = {"from": entry.t1, "to": entry.t2}
                        break
                    # move down in the dictionary
                    state_el = state_el[change_el]
        return new_state
303

304
    @staticmethod
    def get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> Optional[dict]:
        """Pop one item from ``q``; return ``None`` if nothing arrives within the timeout."""
        try:
            wait = timeout or q.default_timeout
            return q.get(timeout=wait)
        except queue.Empty:
            return None

    @staticmethod
    def batch_get_state_changed_from_queue(q: BaseQueue, timeout: Optional[float] = None) -> List[dict]:
        """Pop up to ``BATCH_DELTA_COUNT`` items from ``q``; return ``[]`` on timeout."""
        try:
            wait = timeout or q.default_timeout
            return q.batch_get(timeout=wait, count=BATCH_DELTA_COUNT)
        except queue.Empty:
            return []
319

320
    def check_error_queue(self) -> None:
        """Fail the app if a component pushed an exception onto the error queue."""
        maybe_exc = self.get_state_changed_from_queue(self.error_queue)  # type: ignore[assignment,arg-type]
        if isinstance(maybe_exc, Exception):
            self.exception = maybe_exc
            self.stage = AppStage.FAILED
325

326
    @property
    def flows(self) -> List[Union[LightningWork, "LightningFlow"]]:
        """All flows in the app: the root first, then every nested flow."""
        nested = list(self.root.flows.values())
        return [self.root, *nested]

    @property
    def works(self) -> List[LightningWork]:
        """Every work defined anywhere in the component tree."""
        return self.root.works(recurse=True)

    @property
    def named_works(self) -> List[Tuple[str, LightningWork]]:
        """Every work in the component tree, paired with its name."""
        return self.root.named_works(recurse=True)
340

341
    def _collect_deltas_from_ui_and_work_queues(self) -> List[Union[Delta, _APIRequest, _CommandRequest]]:
        """Drain deltas from the UI and work queues for up to ``state_accumulate_wait`` seconds.

        API and command requests found amongst the received items are forwarded to
        ``_process_requests``. Returns the accumulated state deltas with deepdiff
        error reporting silenced (see comment below).
        """
        # The aggregation tries to get as many deltas as possible from the queue
        # during the `state_accumulate_wait` time window.
        # The while loop exits sooner if the queue is empty.
        deltas = []
        api_or_command_request_deltas = []
        t0 = time()

        while (time() - t0) < self.state_accumulate_wait:
            # TODO: Fetch all available deltas at once to reduce queue calls.
            received_deltas: List[Union[_DeltaRequest, _APIRequest, _CommandRequest, ComponentDelta]] = (
                self.batch_get_state_changed_from_queue(
                    self.delta_queue  # type: ignore[assignment,arg-type]
                )
            )
            # Bug fix: this previously compared ``len(received_deltas) == []`` (an int
            # against a list), which is always False — the loop never exited early on an
            # empty queue and always spun for the full accumulation window.
            if len(received_deltas) == 0:
                break

            for delta in received_deltas:
                if isinstance(delta, _DeltaRequest):
                    deltas.append(delta.delta)
                elif isinstance(delta, ComponentDelta):
                    logger.debug(f"Received from {delta.id} : {delta.delta.to_dict()}")
                    work = None
                    try:
                        work = self.get_component_by_name(delta.id)
                    except (KeyError, AttributeError) as ex:
                        logger.error(f"The component {delta.id} couldn't be accessed. Exception: {ex}")

                    if work:
                        # Re-root the work-local delta onto the full app state.
                        delta = _delta_to_app_state_delta(
                            self.root,  # type: ignore[arg-type]
                            work,
                            deepcopy(delta.delta),
                        )
                        deltas.append(delta)
                else:
                    api_or_command_request_deltas.append(delta)

        if api_or_command_request_deltas:
            _process_requests(self, api_or_command_request_deltas)

        for delta in deltas:
            # When aggregating deltas from the UI and the Works, and over the accumulation time window,
            # it can happen that deltas from these different sources disagree. Since deltas are computed on the Work
            # and UI side separately, correctness of the aggregation can only be guaranteed if both components compute
            # the delta based on the same base state. But this assumption does not hold in general, and there is no way
            # for the Flow to reject or resolve these deltas properly at the moment. Hence, we decide to ignore
            # errors coming from deepdiff when adding deltas together by setting:
            delta.log_errors = False  # type: ignore[union-attr]
            delta.raise_errors = False  # type: ignore[union-attr]
        return deltas
395

396
    def maybe_apply_changes(self) -> Optional[bool]:
        """Get the deltas from both the flow queue and the work queue, merge the two deltas and update the state.

        Returns ``None`` when deltas were applied (or collection is disabled), ``False``
        when no deltas were received from the queues.
        """
        # 1: Forward any flow-side changes down to the works first.
        self._send_flow_to_work_deltas(self.state)

        if not self.collect_changes:
            return None

        deltas = self._collect_deltas_from_ui_and_work_queues()

        if not deltas:
            # Path and Drive aren't processed by DeepDiff, so we need to convert them to dict.
            last_state = apply_to_collection(self.last_state, (Path, Drive), lambda x: x.to_dict())
            state = apply_to_collection(self.state, (Path, Drive), lambda x: x.to_dict())
            # When no deltas are received from the Rest API or work queues,
            # we need to check if the flow modified the state and populate changes.
            deep_diff = DeepDiff(last_state, state, verbose_level=2)

            if "unprocessed" in deep_diff:
                # pop the unprocessed key.
                unprocessed = deep_diff.pop("unprocessed")
                logger.warn(f"It seems delta differentiation resulted in {unprocessed}. Open an issue on Github.")

            if deep_diff:
                # TODO: Resolve changes with ``CacheMissException``.
                # new_state = self.populate_changes(self.last_state, self.state)
                self.set_last_state(self.state)
                self._has_updated = True
            return False

        logger.debug(f"Received {[d.to_dict() for d in deltas]}")

        # 2: Collect the state
        state = self.state

        # 3: Apply the state delta
        for delta in deltas:
            try:
                # Adding a deepdiff ``Delta`` to a dict applies the delta to it.
                state += delta
            except Exception as ex:
                raise Exception(f"Current State {state}, {delta.to_dict()}") from ex

        # new_state = self.populate_changes(self.last_state, state)
        self.set_state(state)
        self._has_updated = True
        return None
441

442
    def run_once(self) -> bool:
        """Method used to collect changes and run the root Flow once.

        Returns ``True`` when the app should stop looping (stopping or failed).
        """
        done = False
        self._last_run_time = 0.0

        if self.backend is not None:
            self.backend.update_work_statuses(self.works)

        self._update_layout()
        self._update_status()
        self.maybe_apply_changes()

        if self.checkpointing and self._should_snapshot():
            self._dump_checkpoint()

        # BLOCKING: wait without stopping (done stays False).
        if self.stage == AppStage.BLOCKING:
            return done

        if self.stage in (AppStage.STOPPING, AppStage.FAILED):
            return True

        if self.stage == AppStage.RESTARTING:
            return self._apply_restarting()

        t0 = time()

        try:
            self.check_error_queue()
            # Execute the flow only if:
            # - There are state changes
            # - It is the first execution of the flow
            if self._has_updated:
                self.root.run()
        except CacheMissException:
            self._on_cache_miss_exception()
        except LightningFlowException:
            done = True
            self.stage = AppStage.FAILED
        except (ExitAppException, KeyboardInterrupt):
            done = True
            self.stage = AppStage.STOPPING

        # ``ready`` latches: once True it is never re-evaluated here.
        if not self.ready:
            self.ready = self.root.ready

        self._last_run_time = time() - t0

        self.on_run_once_end()
        return done
491

492
    def _reset_run_time_monitor(self) -> None:
        """Clear the rolling window of flow-loop durations."""
        self._run_times = [0.0] * FLOW_DURATION_SAMPLES

    def _update_run_time_monitor(self) -> None:
        """Push the latest loop duration into the window and warn when the flow is slow."""
        # Slide the window one slot and record the newest sample at the end.
        self._run_times.pop(0)
        self._run_times.append(self._last_run_time)

        # Here we underestimate during the first FLOW_DURATION_SAMPLES
        # iterations, but that's ok for our purposes
        avg_elapsed_time = sum(self._run_times) / FLOW_DURATION_SAMPLES

        if avg_elapsed_time > FLOW_DURATION_THRESHOLD:
            warnings.warn(
                "The execution of the `run` method of the root flow is taking too long. "
                "Flow is supposed to only host coordination logic, while currently it is"
                "likely to contain long-running calls, code that performs meaningful "
                "computations or that makes blocking or asynchronous calls to third-party "
                "services. If that is the case, you should move those pieces to a Work, "
                "and make sure Flow can complete its execution in under a second.",
                LightningFlowWarning,
            )
513

514
    def _run(self) -> bool:
        """Entry point of the LightningApp.

        This would be dispatched by the Runtime objects.

        """
        # Keep a copy of the initial state so the app can be restored on restart.
        self._original_state = deepcopy(self.state)
        done = False

        self.ready = self.root.ready

        self._start_with_flow_works()

        # Publish the initial state to the REST API before the loop starts.
        if self.should_publish_changes_to_api and self.api_publish_state_queue is not None:
            self.api_publish_state_queue.put((self.state_vars, self.status))

        self._reset_run_time_monitor()

        while not done:
            done = self.run_once()

            self._update_run_time_monitor()

            # Re-publish only when the last iteration changed something.
            if self._has_updated and self.should_publish_changes_to_api and self.api_publish_state_queue is not None:
                self.api_publish_state_queue.put((self.state_vars, self.status))

            self._has_updated = False

        self._on_run_end()

        return True
545

546
    def _update_layout(self) -> None:
        """Recompute and attach the layout of every flow in the app."""
        if self.backend:
            self.backend.resolve_url(self, base_url=None)

        flow_types = (lightning.app.LightningFlow,)
        for component in breadth_first(self.root, types=flow_types):  # type: ignore[arg-type]
            component._layout = _collect_layout(self, component)
553

554
    def _update_status(self) -> None:
        """Refresh ``self.status`` from the works' statuses; flag an update on change."""
        previous_status = self.status

        assert self.root is not None
        statuses = {}
        for work in breadth_first(self.root, types=(lightning.app.LightningWork,)):  # type: ignore[arg-type]
            statuses[work.name] = work.status

        self.status = AppStatus(is_ui_ready=self.ready, work_statuses=statuses)

        # Work status changes already propagate through state deltas; a change of
        # ``ready`` does not, so trigger an update manually when the status differs.
        if self.status != previous_status:
            self._has_updated = True
571

572
    def _apply_restarting(self) -> bool:
        """Roll the app back to its original state and move it to the blocking stage."""
        self._reset_original_state()
        # The stage must be set after the original state has been restored.
        self.stage = AppStage.BLOCKING
        return False

    def _has_work_finished(self, work: LightningWork) -> bool:
        """Whether ``work``'s most recent call has produced a return value."""
        call_hash = work._calls[CacheCallsKeys.LATEST_CALL_HASH]
        if call_hash is None:
            return False
        return "ret" in work._calls[call_hash]

    def _collect_work_finish_status(self) -> dict:
        """Map each work's name to whether its latest call has finished."""
        finished = {work.name: self._has_work_finished(work) for work in self.works}
        assert len(finished) == len(self.works)
        return finished

    def _should_snapshot(self) -> bool:
        """Decide whether a checkpoint should be written after this iteration."""
        if not self.works:
            return True
        if self._has_updated:
            finished = self._collect_work_finish_status()
            return all(finished.values()) if finished else True
        return False
598

599
    def state_dict(self) -> Dict:
        """Return the full application state (alias of the ``state`` property)."""
        return self.state

    def load_state_dict(self, state: Dict) -> None:
        """Restore the application from a previously captured state dict."""
        self.set_state(state)
604

605
    def load_state_dict_from_checkpoint_dir(
        self,
        checkpoints_dir: str,
        version: Optional[int] = None,
    ) -> None:
        """Load the app state from a checkpoint file in ``checkpoints_dir``.

        Checkpoint files are named ``v_{version}_{timestamp}.json``.

        Arguments:
            checkpoints_dir: Directory containing the checkpoint files.
            version: The checkpoint version to restore; defaults to the latest found.

        Raises:
            FileNotFoundError: If the directory or the requested version doesn't exist.
            Exception: If no checkpoints are present, or multiple share the version.
        """
        if not os.path.exists(checkpoints_dir):
            raise FileNotFoundError(f"The provided directory `{checkpoints_dir}` doesn't exist.")
        checkpoints = [f for f in os.listdir(checkpoints_dir) if f.startswith("v_") and f.endswith(".json")]
        if not checkpoints:
            # Fixed typo in the error message ("where" -> "were").
            raise Exception(f"No checkpoints were found in `{checkpoints_dir}`.")

        if version is None:
            # Take the latest checkpoint (max() instead of sorting just for the last item).
            version = max(int(c.split("_")[1]) for c in checkpoints)

        available_checkpoints = [c for c in checkpoints if c.startswith(f"v_{version}_")]
        if not available_checkpoints:
            raise FileNotFoundError(f"The version `{version}` wasn't found in {checkpoints}.")
        if len(available_checkpoints) > 1:
            # Report the actual count (was hard-coded "2") and fix the missing space.
            raise Exception(f"Found {len(available_checkpoints)} checkpoints `{available_checkpoints}` with the same version.")
        checkpoint_path = os.path.join(checkpoints_dir, available_checkpoints[0])
        # NOTE(review): checkpoints are unpickled here — only restore checkpoints from a
        # trusted location, since unpickling can execute arbitrary code.
        with open(checkpoint_path, "rb") as fo:
            state = pickle.load(fo)
        self.load_state_dict(state)
629

630
    def _dump_checkpoint(self) -> Optional[str]:
        """Pickle the current app state to a new versioned checkpoint file.

        Returns the path of the written checkpoint, or ``None`` when the checkpoint
        directory points at S3 (remote saving isn't supported yet).
        """
        checkpoints_dir = self.checkpoint_dir
        # TODO: Add support for saving checkpoints remotely.
        if checkpoints_dir.startswith("s3:"):
            return None
        os.makedirs(checkpoints_dir, exist_ok=True)

        # Determine the next version number from the checkpoint files already present.
        existing_versions = sorted(
            int(f.split("_")[1]) for f in os.listdir(checkpoints_dir) if f.startswith("v_") and f.endswith(".json")
        )
        latest_version = existing_versions[-1] if existing_versions else -1

        checkpoint_path = os.path.join(checkpoints_dir, f"v_{latest_version + 1}_{time()}.json")

        with open(checkpoint_path, "wb") as fo:
            pickle.dump(self.state_dict(), fo)
        return checkpoint_path
649

650
    def connect(self, runtime: "Runtime") -> None:
        """Override to customize your application to the runtime.

        The default implementation does nothing.
        """
        pass

    def _on_cache_miss_exception(self) -> None:
        # Called when ``run_once`` catches a ``CacheMissException`` raised while
        # running the flow; refresh the layout so newly attached components are
        # reflected before the next iteration.
        if self._has_updated:
            self._update_layout()
657

658
    def _register_schedule(self, schedule_hash: str, schedule_metadata: Dict) -> None:
        """Register a flow schedule, lazily starting the scheduler thread.

        The scheduler thread is created only the first time a schedule is registered.
        """
        if not self._schedules:
            scheduler_thread = SchedulerThread(self)
            # ``Thread.setDaemon`` is deprecated since Python 3.10; assign the
            # ``daemon`` attribute instead.
            scheduler_thread.daemon = True
            self.threads.append(scheduler_thread)
            self.threads[-1].start()
        self._schedules[schedule_hash] = deepcopy(schedule_metadata)
666

667
    def on_run_once_end(self) -> None:
        """Hook executed at the end of every ``run_once`` iteration."""
        if self._schedules:
            # Disable any flow schedules that are currently running.
            for flow in self.flows:
                flow._disable_running_schedules()

    def _on_run_end(self) -> None:
        """Undo the debug configuration enabled in ``__init__`` when the app stops."""
        if os.environ.get("LIGHTNING_DEBUG") == "2":
            os.environ.pop("LIGHTNING_DEBUG")
            _console.setLevel(logging.INFO)
678

679
    @staticmethod
    def _extract_vars_from_component_name(component_name: str, state: dict) -> Optional[dict]:
        """Walk ``state`` along the dotted component name and return its public ``vars``.

        Returns ``None`` when the path can't be resolved (e.g. the component was
        dynamically created or deleted). Private keys, drives, payloads and paths
        are filtered out of the result.
        """
        node = state
        for part in component_name.split(".")[1:]:
            if part in node["flows"]:
                node = node["flows"][part]
            elif "structures" in node and part in node["structures"]:
                node = node["structures"][part]
            elif part in node["works"]:
                node = node["works"][part]
            else:
                return None

        def _is_public(key: str, value) -> bool:
            # Filter private keys and drives
            if key.startswith("_"):
                return False
            if isinstance(value, dict) and value.get("type", None) == "__drive__":
                return False
            return not isinstance(value, (Payload, Path))

        return {k: v for k, v in node["vars"].items() if _is_public(k, v)}
702

703
    def _send_flow_to_work_deltas(self, state: dict) -> None:
        """Push flow-side state changes down to each started work's delta queue."""
        if not self.flow_to_work_delta_queues:
            return

        for work in self.works:
            if not work.has_started:
                continue

            # Don't send changes when the state has been just sent.
            if work.run.has_sent:
                continue

            current_vars = self._extract_vars_from_component_name(work.name, state)
            previous_vars = self._extract_vars_from_component_name(work.name, self._last_state)

            # Note: The work was dynamically created or deleted.
            if current_vars is None or previous_vars is None:
                continue

            deep_diff = DeepDiff(previous_vars, current_vars, verbose_level=2).to_dict()
            deep_diff.pop("unprocessed", None)

            if deep_diff:
                logger.debug(f"Sending deep_diff to {work.name} : {deep_diff}")
                self.flow_to_work_delta_queues[work.name].put(deep_diff)
730

731
    def _start_with_flow_works(self) -> None:
        """Start every work flagged ``_start_with_flow``, forcing a non-blocking start."""
        for work in self.works:
            if not work._start_with_flow:
                continue
            # Temporarily mark the work as parallel so ``start`` doesn't block,
            # then restore its original setting.
            original_parallel = work.parallel
            work._parallel = True
            work.start()
            work._parallel = original_parallel
738

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.