pytorch-lightning

Форк
0
174 строки · 5.5 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import abc
16
import inspect
17
import os
18
import pydoc
19
import subprocess
20
import sys
21
from typing import Any, Callable, Type
22

23
from lightning.app.core.work import LightningWork
24
from lightning.app.utilities.app_helpers import StreamLitStatePlugin
25
from lightning.app.utilities.state import AppState
26

27

28
class ServeStreamlit(LightningWork, abc.ABC):
    """Serve a streamlit app from inside a Lightning work.

    Subclasses implement :meth:`render` with ordinary streamlit calls. Optionally, :meth:`build_model` can be
    overridden to construct a model once per streamlit session; it is then available as ``self.model`` while
    rendering.

    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

        # Flipped to True by ``run`` right after the streamlit subprocess has been launched.
        self.ready = False

        # Handle of the ``streamlit run`` subprocess; stays ``None`` until ``run`` starts it.
        self._process = None

    @property
    def model(self) -> Any:
        """The object attached as ``_model`` (the result of ``build_model``), or ``None`` if nothing is attached."""
        return getattr(self, "_model", None)

    @abc.abstractmethod
    def render(self) -> None:
        """Override with your streamlit render function."""

    def build_model(self) -> Any:
        """Optionally override to instantiate and return your model.

        Whatever this returns is exposed as ``self.model``. The default implementation builds nothing.

        """
        return None

    def run(self) -> None:
        """Launch ``streamlit run`` on this module and block until that subprocess exits.

        The subprocess learns which work it should emulate through ``LIGHTNING_*`` environment variables.

        """
        child_env = {
            **os.environ,
            "LIGHTNING_COMPONENT_NAME": self.name,
            "LIGHTNING_WORK": self.__class__.__name__,
            # Path of the module that defines the concrete work subclass, so the child can import it.
            "LIGHTNING_WORK_MODULE_FILE": inspect.getmodule(self).__file__,
        }
        command = [
            sys.executable,
            "-m",
            "streamlit",
            "run",
            __file__,  # this very file acts as the streamlit script
            "--server.address",
            str(self.host),
            "--server.port",
            str(self.port),
            # Do not open a browser window when running locally.
            "--server.headless",
            "true",
        ]

        self._process = subprocess.Popen(command, env=child_env)
        self.ready = True
        self._process.wait()

    def on_exit(self) -> None:
        """Kill the streamlit subprocess, if ``run`` has started one."""
        process = self._process
        if process is None:
            return
        process.kill()

    def configure_layout(self) -> str:
        """Use this work's URL, where the streamlit server listens, as its UI layout."""
        return self.url
88

89

90
class _PatchedWork:
91
    """The ``_PatchedWork`` is used to emulate a work instance from a subprocess. This is acheived by patching the self
92
    reference in methods an properties to point to the AppState.
93

94
    Args:
95
        state: The work state to patch
96
        work_class: The work class to emulate
97

98
    """
99

100
    def __init__(self, state: AppState, work_class: Type):
101
        super().__init__()
102
        self._state = state
103
        self._work_class = work_class
104

105
    def __getattr__(self, name: str) -> Any:
106
        try:
107
            return getattr(self._state, name)
108
        except AttributeError:
109
            # The name isn't in the state, so check if it's a callable or a property
110
            attribute = inspect.getattr_static(self._work_class, name)
111
            if callable(attribute):
112
                attribute = attribute.__get__(self, self._work_class)
113
                return attribute
114
            if isinstance(attribute, (staticmethod, property)):
115
                return attribute.__get__(self, self._work_class)
116

117
            # Look for the name in the instance (e.g. for private variables)
118
            return object.__getattribute__(self, name)
119

120
    def __setattr__(self, name: str, value: Any) -> None:
121
        if name in ["_state", "_work_class"]:
122
            return object.__setattr__(self, name, value)
123

124
        if hasattr(self._state, name):
125
            return setattr(self._state, name, value)
126
        return object.__setattr__(self, name, value)
127

128

129
def _reduce_to_component_scope(state: AppState, component_name: str) -> AppState:
    """Walk down ``state`` to the sub-state of the component called ``component_name``.

    ``component_name`` is a dotted path whose first segment (the root) is skipped; every remaining segment is looked
    up as an attribute of the previous level.

    """
    _, *path = component_name.split(".")
    scoped = state
    for attribute_name in path:
        scoped = getattr(scoped, attribute_name)
    return scoped
136

137

138
def _get_work_class() -> Callable:
139
    """Import the work class specified in the environment."""
140
    work_name = os.environ["LIGHTNING_WORK"]
141
    work_module_file = os.environ["LIGHTNING_WORK_MODULE_FILE"]
142
    module = pydoc.importfile(work_module_file)
143
    return getattr(module, work_name)
144

145

146
def _build_model(work: ServeStreamlit) -> None:
    """Attach a model to ``work``, calling ``work.build_model()`` at most once per streamlit session.

    The result is cached in ``st.session_state`` under ``"_model"``, so later runs in the same session reuse it
    (equivalent to gradio when ``enable_queue`` is False).

    """
    import streamlit as st

    session = st.session_state
    if "_model" not in session:
        # First run in this session: build and cache the model while showing a spinner.
        with st.spinner("Building model..."):
            session["_model"] = work.build_model()

    work._model = session["_model"]
155

156

157
def _main() -> None:
    """Rebuild the work inside the streamlit script and render it.

    The app state is narrowed to this work's component, wrapped in a ``_PatchedWork`` that emulates the work class,
    given its (once-per-session) model, and rendered.

    """
    root_state = AppState(plugin=StreamLitStatePlugin())
    scoped_state = _reduce_to_component_scope(root_state, os.environ["LIGHTNING_COMPONENT_NAME"])

    work = _PatchedWork(scoped_state, _get_work_class())
    _build_model(work)
    work.render()


if __name__ == "__main__":
    # ``ServeStreamlit.run`` launches ``streamlit run`` on this very file, so this guard is the script's entry point.
    _main()
175

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.