# pytorch-lightning
#
# Fork: 0
# 582 lines · 18.6 KB
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import abc
16
import asyncio
17
import builtins
18
import enum
19
import functools
20
import inspect
21
import json
22
import logging
23
import os
24
import sys
25
import threading
26
import time
27
from abc import ABC, abstractmethod
28
from contextlib import contextmanager
29
from copy import deepcopy
30
from dataclasses import dataclass, field
31
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Type
32
from unittest.mock import MagicMock
33

34
import websockets
35
from deepdiff import Delta
36

37
import lightning.app
38
from lightning.app.utilities.exceptions import LightningAppStateException
39
from lightning.app.utilities.tree import breadth_first
40

41
if TYPE_CHECKING:
    # Only needed by static type checkers; annotations below refer to these names as strings,
    # so nothing is imported at runtime.
    from lightning.app.core.app import LightningApp
    from lightning.app.core.flow import LightningFlow
    from lightning.app.utilities.types import Component

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
47

48

49
@dataclass
class StateEntry:
    """Per-key record of the state exchanged through the app REST API.

    Each field gets a fresh default, so entries never share mutable state.
    """

    # Most recent app state stored for this key.
    app_state: Mapping = field(default_factory=dict)
    # App state that was last served for this key.
    served_state: Mapping = field(default_factory=dict)
    # Identifier of the session associated with the served state, if any.
    session_id: Optional[str] = field(default=None)
56

57

58
class StateStore(ABC):
    """Abstract key/value store holding, per key, the app state, the last served state and a session id.

    Concrete stores must implement every method below; the base implementations do nothing
    and return ``None``.
    """

    @abstractmethod
    def __init__(self):
        """Set up the backing storage."""

    @abstractmethod
    def add(self, k: str):
        """Create a new, empty state entry under key ``k``."""

    @abstractmethod
    def remove(self, k: str):
        """Delete the state entry stored under key ``k``."""

    @abstractmethod
    def get_app_state(self, k: str) -> Mapping:
        """Return the app state stored under key ``k``."""

    @abstractmethod
    def get_served_state(self, k: str) -> Mapping:
        """Return the app state last served for key ``k``."""

    @abstractmethod
    def get_served_session_id(self, k: str) -> str:
        """Return the session id recorded for key ``k``."""

    @abstractmethod
    def set_app_state(self, k: str, v: Mapping):
        """Store ``v`` as the app state of key ``k``."""

    @abstractmethod
    def set_served_state(self, k: str, v: Mapping):
        """Store ``v`` as the served state of key ``k``."""

    @abstractmethod
    def set_served_session_id(self, k: str, v: str):
        """Record ``v`` as the session id of key ``k``."""
104

105

106
class InMemoryStateStore(StateStore):
    """``StateStore`` keeping every ``StateEntry`` in a process-local dictionary."""

    def __init__(self):
        # key -> StateEntry
        self.store = {}
        # Number of successful ``set_app_state`` calls.
        self.counter = 0

    def add(self, k):
        self.store[k] = StateEntry()

    def remove(self, k):
        self.store.pop(k)

    def get_app_state(self, k):
        entry = self.store[k]
        return entry.app_state

    def get_served_state(self, k):
        entry = self.store[k]
        return entry.served_state

    def get_served_session_id(self, k):
        entry = self.store[k]
        return entry.session_id

    def set_app_state(self, k, v):
        # NOTE(review): sys.getsizeof is shallow -- for a dict it measures the container, not the
        # nested values -- so this guard is only a coarse approximation of the real state size.
        state_size = sys.getsizeof(v)
        max_size = lightning.app.core.constants.APP_STATE_MAX_SIZE_BYTES
        if state_size > max_size:
            raise LightningAppStateException(
                f"App state size is {state_size} bytes, which is larger than the recommended size "
                f"of {max_size}. Please investigate this."
            )
        # Store a private copy so later mutations by the caller do not leak into the store.
        self.store[k].app_state = deepcopy(v)
        self.counter += 1

    def set_served_state(self, k, v):
        entry = self.store[k]
        entry.served_state = deepcopy(v)

    def set_served_session_id(self, k, v):
        entry = self.store[k]
        entry.session_id = v
143

144

145
class _LightningAppRef:
    """Class-level holder for the current ``LightningApp`` instance."""

    # Populated by ``connect``; ``None`` until an app registers itself.
    _app_instance: Optional["LightningApp"] = None

    @classmethod
    def connect(cls, app_instance: "LightningApp") -> None:
        """Register ``app_instance`` as the current app."""
        cls._app_instance = app_instance

    @classmethod
    def get_current(cls) -> Optional["LightningApp"]:
        """Return the registered app, or ``None`` if the registered value is missing or falsy."""
        return cls._app_instance or None
157

158

159
def affiliation(component: "Component") -> Tuple[str, ...]:
    """Returns the affiliation of a component.

    That is, the parts of its dotted name below the root, e.g. ``"root.a.b"`` -> ``("a", "b")``.
    The root (``"root"``) and unnamed components (``""``) map to ``()``.
    """
    full_name = component.name
    if full_name in ("root", ""):
        return ()
    _first, *below_root = full_name.split(".")
    return tuple(below_root)
164

165

166
class AppStateType(str, enum.Enum):
    """Kinds of context a state plugin can report.

    The ``str`` mix-in makes each member compare equal to its string value.
    """

    STREAMLIT = "STREAMLIT"
    DEFAULT = "DEFAULT"
169

170

171
class BaseStatePlugin(abc.ABC):
    """Interface for plugins customising how app-state updates are handled."""

    def __init__(self):
        # Authorization status; starts unset.
        self.authorized = None

    @abc.abstractmethod
    def should_update_app(self, deep_diff):
        """Decide what to do with ``deep_diff`` (the plugins in this module return it unchanged)."""

    @abc.abstractmethod
    def get_context(self):
        """Return context information (the plugins in this module return a ``{"type": ...}`` dict)."""

    @abc.abstractmethod
    def render_non_authorized(self):
        """Hook for non-authorized users; the plugins in this module do nothing here."""
186

187

188
class AppStatePlugin(BaseStatePlugin):
    """Default state plugin: no filtering, ``DEFAULT`` context, nothing to render."""

    def should_update_app(self, deep_diff):
        # Returned unchanged: no filtering is applied.
        return deep_diff

    def get_context(self):
        return {"type": AppStateType.DEFAULT.value}

    def render_non_authorized(self):
        """Nothing to render for the default plugin."""
197

198

199
def target_fn():
    """Request Streamlit reruns whenever the app server sends a websocket message.

    Connects to ``ws://<host>/api/v1/ws``. On every received message it asks each Streamlit
    session captured at start-up to rerun, at most once per second. It stops when the server
    closes the socket normally. If no Streamlit runtime/server exists in this process, it
    returns immediately.
    """
    try:
        # streamlit >= 1.14.0
        from streamlit import runtime

        get_instance = runtime.get_instance
        exists = runtime.exists()
    except ImportError:
        # Older versions
        from streamlit.server.server import Server

        get_instance = Server.get_current
        exists = bool(Server._singleton)

    min_interval = 1.0  # minimum number of seconds between two rerun requests

    async def update_fn():
        runtime_instance = get_instance()
        # NOTE: the session list is captured once, before connecting.
        sessions = list(runtime_instance._session_info_by_id.values())
        url = (
            "localhost:8080"
            if "LIGHTNING_APP_STATE_URL" in os.environ
            else f"localhost:{lightning.app.core.constants.APP_SERVER_PORT}"
        )
        ws_url = f"ws://{url}/api/v1/ws"
        # Monotonic clock: throttling must not be affected by wall-clock adjustments.
        last_updated = time.monotonic()
        async with websockets.connect(ws_url) as websocket:
            while True:
                try:
                    _ = await websocket.recv()

                    # Throttle with a non-blocking sleep so other coroutines on this event loop
                    # (e.g. the websocket library's own background work) keep running.
                    delay = min_interval - (time.monotonic() - last_updated)
                    if delay > 0:
                        await asyncio.sleep(delay)
                    for session_info in sessions:
                        session = session_info.session
                        session.request_rerun(session._client_state)
                    last_updated = time.monotonic()
                except websockets.exceptions.ConnectionClosedOK:
                    # The websocket is not enabled
                    break

    if exists:
        asyncio.run(update_fn())
240

241

242
class StreamLitStatePlugin(BaseStatePlugin):
    """State plugin used from Streamlit apps.

    On construction it starts a background daemon thread running ``target_fn``. This happens
    at most once per ``st.session_state``, tracked under the ``"websocket_thread"`` key.
    """

    def __init__(self):
        super().__init__()
        import streamlit as st

        if hasattr(st, "session_state") and "websocket_thread" not in st.session_state:
            # Pass daemon=True to the constructor: Thread.setDaemon() is deprecated since Python 3.10.
            thread = threading.Thread(target=target_fn, daemon=True)
            st.session_state.websocket_thread = thread
            thread.start()

    def should_update_app(self, deep_diff):
        return deep_diff

    def get_context(self):
        # NOTE(review): reports DEFAULT (not STREAMLIT); kept as-is since consumers may rely on it.
        return {"type": AppStateType.DEFAULT.value}

    def render_non_authorized(self):
        pass
261

262

263
def is_overridden(method_name: str, instance: Optional[object] = None, parent: Optional[Type[object]] = None) -> bool:
    """Return whether ``instance`` overrides ``method_name`` with respect to ``parent``.

    A ``None`` instance never overrides anything. When ``parent`` is omitted it is inferred as
    ``LightningFlow`` or ``LightningWork`` from the instance's type.

    Raises:
        ValueError: if ``parent`` is omitted and the instance is neither a flow nor a work.
    """
    if instance is None:
        return False

    if parent is None:
        for candidate in (lightning.app.LightningFlow, lightning.app.LightningWork):
            if isinstance(instance, candidate):
                parent = candidate
                break
        else:
            raise ValueError("Expected a parent")

    from lightning_utilities.core.overrides import is_overridden as _utilities_is_overridden

    return _utilities_is_overridden(method_name, instance, parent)
276

277

278
def _is_json_serializable(x: Any) -> bool:
    """Test whether a variable can be encoded as json.

    Primitive types listed in ``SUPPORTED_PRIMITIVE_TYPES`` are accepted without encoding.
    Anything else is tried with ``json.dumps(..., cls=LightningJSONEncoder)``. Every encoding
    failure is reported as ``False`` rather than raised.
    """
    if type(x) in lightning.app.core.constants.SUPPORTED_PRIMITIVE_TYPES:
        # shortcut for primitive types that are not containers
        return True
    try:
        json.dumps(x, cls=LightningJSONEncoder)
    except (TypeError, OverflowError, ValueError):
        # TypeError: unsupported object; OverflowError: number too large to encode;
        # ValueError: circular reference (or an int too long to convert to str on newer Pythons)
        return False
    return True
289

290

291
def _set_child_name(component: "Component", child: "Component", new_name: str) -> str:
    """Rename ``child`` to ``"<component.name>.<new_name>"`` and re-derive its descendants' names.

    Flows rename their child flows, works and structures. ``Dict`` structures rename their
    items by key, and ``List`` structures by the last part of each item's current name.
    Returns the child's new name.
    """
    full_name = f"{component.name}.{new_name}"
    child._name = full_name

    # The prefix changed, so every descendant must be renamed as well.
    if isinstance(child, lightning.app.core.LightningFlow):
        for attr_name in (*child._flows, *child._works, *child._structures):
            _set_child_name(child, getattr(child, attr_name), attr_name)
    if isinstance(child, lightning.app.structures.Dict):
        for key, item in child.items():
            _set_child_name(child, item, key)
    if isinstance(child, lightning.app.structures.List):
        for item in child:
            # List items keep their index, i.e. the last segment of their current name.
            _set_child_name(child, item, item.name.split(".")[-1])

    return full_name
315

316

317
def _delta_to_app_state_delta(root: "LightningFlow", component: "Component", delta: "Delta") -> "Delta":
    """Re-root a ``Delta`` computed on ``component``'s own state so it applies to the app state.

    Each path key of ``delta`` starts with ``root`` (relative to ``component``). That leading
    ``root`` is replaced by the full path of ``component`` inside the app state, e.g.
    ``root['works']['w']`` for a work ``w`` attached to the root flow.
    """
    delta_dict = delta.to_dict()
    # The prefix depends only on (root, component): compute it once, lazily, so an empty delta
    # never walks the component tree.
    new_prefix = None
    for changed in delta_dict.values():
        for delta_key in list(changed):
            if new_prefix is None:
                new_prefix = _component_state_prefix(root, component)
            # the first 4 chars are the word 'root', strip it
            new_key = new_prefix + delta_key[4:]
            if new_key != delta_key:
                changed[new_key] = changed.pop(delta_key)

    return Delta(delta_dict)


def _component_state_prefix(root: "LightningFlow", component: "Component") -> str:
    """Return the app-state path (``root['flows'][...]...``) leading to ``component`` from ``root``."""
    prefix = "root"
    for _, node in _walk_to_component(root, component):
        if isinstance(node, lightning.app.core.LightningWork):
            prefix += "['works']"
        if isinstance(node, lightning.app.core.LightningFlow):
            prefix += "['flows']"
        if isinstance(node, (lightning.app.structures.Dict, lightning.app.structures.List)):
            prefix += "['structures']"
        short_name = node.name.split(".")[-1]
        prefix += f"['{short_name}']"
    return prefix
344

345

346
def _walk_to_component(
    root: "LightningFlow",
    component: "Component",
) -> Generator[Tuple["Component", "Component"], None, None]:
    """Walk from ``root`` down to ``component``, yielding ``(parent, child)`` at every step.

    The path is read from ``component.name`` (``"root.<a>.<b>..."``). ``Dict`` structures are
    indexed by key, ``List`` structures by integer index, and anything else by attribute.
    """
    from lightning.app.structures import Dict, List

    current = root
    for part in component.name.split(".")[1:]:  # the leading 'root' is the starting point itself
        if isinstance(current, Dict):
            nxt = current[part]
        elif isinstance(current, List):
            nxt = current[int(part)]
        else:
            nxt = getattr(current, part)
        yield current, nxt
        current = nxt
366

367

368
def _collect_child_process_pids(pid: int) -> List[int]:
    """Function to return the list of child process pid's of a process.

    Only children whose command name contains ``python`` (case-insensitive) are returned.
    Returns an empty list when ``ps`` cannot be executed.
    """
    import subprocess

    try:
        # `-A -o pid=,ppid=,comm=` works with both procps (Linux) and BSD/macOS ps and prints
        # exactly "PID PPID COMMAND" per process with no header. (`ps -j` columns differ
        # between the two: BSD prints PID/PPID, procps prints PGID/SID.)
        output = subprocess.run(
            ["ps", "-A", "-o", "pid=,ppid=,comm="],
            capture_output=True,
            text=True,
            check=False,
        ).stdout
    except OSError:
        return []

    children = []
    for line in output.splitlines():
        fields = line.split(None, 2)
        if len(fields) < 3 or not fields[0].isdigit():
            continue
        child, parent, command = fields
        if "python" not in command.lower():
            continue
        if parent == str(pid) and child != str(pid):
            children.append(int(child))
    return children
373

374

375
def _print_to_logger_info(*args: Any, **kwargs: Any):
    """Drop-in replacement for ``print`` that forwards the message to ``lightning.app._logger.info``.

    Arguments are converted with ``str`` and joined with ``sep`` (a single space when omitted
    or ``None``, as for ``print``). Other ``print`` keywords (``end``, ``file``, ``flush``) are
    accepted and ignored.
    """
    # TODO Find a better way to re-direct print to loggers.
    sep = kwargs.get("sep")
    if sep is None:
        sep = " "
    lightning.app._logger.info(sep.join(str(v) for v in args))
378

379

380
def convert_print_to_logger_info(func: Callable) -> Callable:
    """This function is used to transform any print into logger.info calls, so it gets tracked in the cloud.

    While the decorated function runs, the built-in ``print`` is replaced by
    ``_print_to_logger_info``. The original ``print`` is restored afterwards, even when the
    function raises.
    """

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # NOTE: this patches the ``builtins`` module, so the swap is visible process-wide
        # (including other threads) while ``func`` runs.
        original_print = builtins.print
        builtins.print = _print_to_logger_info
        try:
            return func(*args, **kwargs)
        finally:
            builtins.print = original_print

    return wrapper
392

393

394
def pretty_state(state: Dict) -> Dict:
    """Utility to prettify the state by removing hidden attributes.

    Returns a new dict holding only the public (non ``_``-prefixed) ``vars`` plus recursively
    prettified ``flows`` and ``works``. ``vars`` is omitted when no public variable remains.
    ``flows``/``works`` are omitted when absent or empty in the input. Other sections are dropped.
    """
    result = {}

    public_vars = {name: value for name, value in state["vars"].items() if not name.startswith("_")}
    if public_vars:
        result["vars"] = public_vars

    for section in ("flows", "works"):
        if section not in state:
            continue
        children = {name: pretty_state(child) for name, child in state[section].items()}
        if children:
            result[section] = children

    return result
413

414

415
class LightningJSONEncoder(json.JSONEncoder):
    """``json.JSONEncoder`` that also encodes objects exposing a callable ``__json__``.

    Such objects are encoded as the value returned by their ``__json__()``.
    """

    def default(self, obj: Any) -> Any:
        if not callable(getattr(obj, "__json__", None)):
            # Defer to the base implementation, which raises TypeError for unsupported objects.
            return json.JSONEncoder.default(self, obj)
        return obj.__json__()
420

421

422
class Logger:
    """This class is used to improve the debugging experience.

    Thin wrapper around ``logging.getLogger(name)``. Its level is set lazily on the first
    logging call: ``DEBUG`` when the ``LIGHTNING_DEBUG`` environment variable is a non-zero
    integer, ``INFO`` otherwise.
    """

    def __init__(self, name: str):
        self.logger = logging.getLogger(name)
        # Resolved on first use, so LIGHTNING_DEBUG may still be changed after construction.
        self.level = None

    def info(self, msg, *args: Any, **kwargs: Any):
        self._set_level()
        self.logger.info(msg, *args, **kwargs)

    def warn(self, msg, *args: Any, **kwargs: Any):
        self._set_level()
        # ``logging.Logger.warn`` is a deprecated alias of ``warning``.
        self.logger.warning(msg, *args, **kwargs)

    # Standard-library spelling, for callers used to ``logging.Logger``.
    warning = warn

    def debug(self, msg, *args: Any, **kwargs: Any):
        self._set_level()
        self.logger.debug(msg, *args, **kwargs)

    def error(self, msg, *args: Any, **kwargs: Any):
        self._set_level()
        self.logger.error(msg, *args, **kwargs)

    def _set_level(self):
        """Lazily set the level once set by the users."""
        # Set on the first info, warn, debug or error call.
        if self.level is None:
            self.level = logging.DEBUG if bool(int(os.getenv("LIGHTNING_DEBUG", "0"))) else logging.INFO
            self.logger.setLevel(self.level)
450

451

452
def _state_dict(flow: "LightningFlow"):
    """Collect a flat ``{component name: state}`` mapping for ``flow`` and its descendants.

    Flows (``flow`` itself plus ``flow.flows``) contribute their ``state_dict()``. Works from
    ``flow.works()`` contribute their ``state`` attribute and are added after the flows.
    """
    all_flows = [flow, *flow.flows.values()]
    state = {f.name: f.state_dict() for f in all_flows}
    state.update((w.name, w.state) for w in flow.works())
    return state
460

461

462
def _load_state_dict(root_flow: "LightningFlow", state: Dict[str, Any], strict: bool = True) -> None:
    """This function is used to reload the state assuming dynamic components creation.

    When a component isn't found but its state exists, its state is passed up to its closest existing parent.

    Arguments:
        root_flow: The flow at the top of the component tree.
        state: The collected state dict, keyed by component name. Entries of the existing works
            and flows are popped from it.
        strict: Whether to validate all components have been re-created.

    Raises:
        KeyError: if an existing work or flow has no entry in ``state``.
        Exception: if ``strict`` and a dynamic component wasn't re-created while loading.
    """
    # 1: Reload the state of the existing works
    for w in root_flow.works():
        w.set_state(state.pop(w.name))

    # 2: Collect the existing flows
    flows = [root_flow] + list(root_flow.flows.values())
    flow_map = {f.name: f for f in flows}

    # 3: Find the state of the all dynamic components
    dynamic_components = {k: v for k, v in state.items() if k not in flow_map}

    # 4: Propagate the state of the dynamic components to their closest existing parent flow
    dynamic_children_state: Dict[str, Dict[str, Any]] = {}
    for name, component_state in dynamic_components.items():
        parent_name = _closest_existing_parent(name, flow_map)
        if parent_name is None:
            continue
        # Strip only the leading "<parent_name>." (str.replace would also remove later occurrences).
        relative_name = name[len(parent_name) + 1 :]
        dynamic_children_state.setdefault(parent_name, {})[relative_name] = component_state

    # 5: Reload the flow states
    for flow_name, flow in flow_map.items():
        flow.load_state_dict(state.pop(flow_name), dynamic_children_state.get(flow_name, {}), strict=strict)

    # 6: Verify all dynamic components have been re-created.
    if strict:
        components_names = {
            root_flow.name,
            *(f.name for f in root_flow.flows.values()),
            *(w.name for w in root_flow.works()),
        }
        for component_name in dynamic_components:
            if component_name not in components_names:
                raise Exception(f"The component {component_name} wasn't re-created during state reloading.")


def _closest_existing_parent(name: str, flow_map: Dict[str, Any]) -> Optional[str]:
    """Return the longest proper dotted prefix of ``name`` that is a key of ``flow_map``, or ``None``."""
    parts = name.split(".")
    for end in range(len(parts) - 1, 0, -1):
        candidate = ".".join(parts[:end])
        if candidate in flow_map:
            return candidate
    return None
514

515

516
class _MagicMockJsonSerializable(MagicMock):
    """``MagicMock`` that encoders honouring ``__json__`` can serialize.

    ``__json__`` is a static method, so it works when called on the class or on an instance.
    """

    @staticmethod
    def __json__():
        # Always encoded as the string "{}".
        return "{}"
520

521

522
def _mock_import(*args, original_fn=None):
    """Import via ``original_fn(*args)``; on any failure return a ``_MagicMockJsonSerializable``.

    Meant to be bound with ``functools.partial`` and installed as ``builtins.__import__``.
    """
    try:
        module = original_fn(*args)
    except Exception:
        # Any import failure (missing package, error raised while importing, ...) becomes a mock.
        module = _MagicMockJsonSerializable()
    return module
527

528

529
@contextmanager
def _mock_missing_imports():
    """Context manager that makes failing imports return mocks instead of raising.

    ``builtins.__import__`` is replaced by ``_mock_import`` (bound to the real import function)
    for the duration of the ``with`` block. It is restored on exit, including when the block raises.
    """
    real_import = builtins.__import__
    patched_import = functools.partial(_mock_import, original_fn=real_import)
    builtins.__import__ = patched_import
    try:
        yield
    finally:
        builtins.__import__ = real_import
537

538

539
def is_static_method(klass_or_instance, attr) -> bool:
    """Return ``True`` when ``attr`` is defined as a ``staticmethod`` on ``klass_or_instance``.

    ``inspect.getattr_static`` fetches the raw descriptor. A normal attribute lookup would unwrap
    the staticmethod into a plain function.
    """
    raw_attribute = inspect.getattr_static(klass_or_instance, attr)
    return isinstance(raw_attribute, staticmethod)
541

542

543
def _lightning_dispatched() -> bool:
    """Return whether ``LIGHTNING_DISPATCHED`` is set to a non-zero integer (unset counts as 0).

    A value that is not an integer makes ``int`` raise ``ValueError``.
    """
    raw_value = os.environ.get("LIGHTNING_DISPATCHED", "0")
    return int(raw_value) != 0
545

546

547
def _using_debugger() -> bool:
    """This method is used to detect whether the app is run with a debugger attached.

    It looks for ``debugpy`` (VS Code) or ``pydev`` (PyCharm) in the command line of the current
    process. A positive result is cached in the ``LIGHTNING_DETECTED_DEBUGGER`` environment variable.
    """
    if "LIGHTNING_DETECTED_DEBUGGER" in os.environ:
        return True

    import subprocess

    # Query only this process. Grepping `ps -ax` for the pid would also match any line that merely
    # contains those digits (e.g. pid 12 inside 123), reporting another process's debugger.
    try:
        command_line = subprocess.run(
            ["ps", "-o", "args=", "-p", str(os.getpid())],
            capture_output=True,
            text=True,
            check=False,
        ).stdout
    except OSError:
        # `ps` is unavailable: assume no debugger.
        command_line = ""

    # Detect whether VSCode or PyCharm debugger are used
    use_debugger = "debugpy" in command_line or "pydev" in command_line

    # Store the result to avoid multiple ps calls.
    if use_debugger:
        os.environ["LIGHTNING_DETECTED_DEBUGGER"] = "1"
    return use_debugger
562

563

564
def _should_dispatch_app() -> bool:
    """Return whether the app should be dispatched from the current process.

    True only when the app has not been dispatched yet, ``LIGHTNING_APP_STATE_URL`` is unset,
    and a debugger is detected.
    """
    if _lightning_dispatched():
        return False
    if "LIGHTNING_APP_STATE_URL" in os.environ:
        return False
    # Checked last: debugger detection is the expensive probe and must be skipped if already dispatched.
    return _using_debugger()
571

572

573
def _is_headless(app: "LightningApp") -> bool:
    """Utility which returns True if the given App has no ``Frontend`` objects or URLs exposed through
    ``configure_layout``."""
    if app.frontends:
        return False
    flows = breadth_first(app.root, types=(lightning.app.LightningFlow,))
    # A layout entry carrying a "target" exposes a URL, so the app is not headless.
    return not any("target" in entry for flow in flows for entry in flow._layout)
583

# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal data in order to improve our website and the GitVerse service and to make them more convenient to use.
#
# You can disable cookies yourself in your browser settings.