pytorch-lightning
174 строки · 5.5 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import abc
16import inspect
17import os
18import pydoc
19import subprocess
20import sys
21from typing import Any, Callable, Type
22
23from lightning.app.core.work import LightningWork
24from lightning.app.utilities.app_helpers import StreamLitStatePlugin
25from lightning.app.utilities.state import AppState
26
27
class ServeStreamlit(LightningWork, abc.ABC):
    """Serve a streamlit app from inside a Lightning work.

    Subclasses must implement :meth:`render`. They may also override :meth:`build_model`; whatever it returns is
    exposed as ``self.model`` and is built only once per streamlit session.

    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)

        # Flipped to True by ``run`` right after the streamlit subprocess has been launched.
        self.ready = False

        # Handle of the streamlit subprocess; stays ``None`` until ``run`` starts it.
        self._process = None

    @property
    def model(self) -> Any:
        """The model attached by the streamlit script, or ``None`` when none has been attached."""
        return getattr(self, "_model", None)

    @abc.abstractmethod
    def render(self) -> None:
        """Override with your streamlit render function."""

    def build_model(self) -> Any:
        """Optionally override to create and return the model used by :meth:`render`.

        The returned object is reachable as ``self.model``. The default builds nothing and returns ``None``.

        """
        return None

    def run(self) -> None:
        """Run ``streamlit run`` on this module in a subprocess and block until that subprocess exits.

        The subprocess receives this work's component name, class name and defining module file through
        ``LIGHTNING_*`` environment variables, so the script side can rebuild the work from the app state.

        """
        child_env = {
            **os.environ,
            "LIGHTNING_COMPONENT_NAME": self.name,
            "LIGHTNING_WORK": self.__class__.__name__,
            "LIGHTNING_WORK_MODULE_FILE": inspect.getmodule(self).__file__,
        }

        command = [sys.executable, "-m", "streamlit", "run", __file__]
        command += ["--server.address", str(self.host)]
        command += ["--server.port", str(self.port)]
        # Headless mode keeps streamlit from opening a browser window when running locally.
        command += ["--server.headless", "true"]

        self._process = subprocess.Popen(command, env=child_env)
        self.ready = True
        self._process.wait()

    def on_exit(self) -> None:
        """Kill the streamlit subprocess if ``run`` has started one."""
        process = self._process
        if process is None:
            return
        process.kill()

    def configure_layout(self) -> str:
        """Use this work's ``url`` as its layout."""
        return self.url
88
89
90class _PatchedWork:
91"""The ``_PatchedWork`` is used to emulate a work instance from a subprocess. This is acheived by patching the self
92reference in methods an properties to point to the AppState.
93
94Args:
95state: The work state to patch
96work_class: The work class to emulate
97
98"""
99
100def __init__(self, state: AppState, work_class: Type):
101super().__init__()
102self._state = state
103self._work_class = work_class
104
105def __getattr__(self, name: str) -> Any:
106try:
107return getattr(self._state, name)
108except AttributeError:
109# The name isn't in the state, so check if it's a callable or a property
110attribute = inspect.getattr_static(self._work_class, name)
111if callable(attribute):
112attribute = attribute.__get__(self, self._work_class)
113return attribute
114if isinstance(attribute, (staticmethod, property)):
115return attribute.__get__(self, self._work_class)
116
117# Look for the name in the instance (e.g. for private variables)
118return object.__getattribute__(self, name)
119
120def __setattr__(self, name: str, value: Any) -> None:
121if name in ["_state", "_work_class"]:
122return object.__setattr__(self, name, value)
123
124if hasattr(self._state, name):
125return setattr(self._state, name, value)
126return object.__setattr__(self, name, value)
127
128
def _reduce_to_component_scope(state: AppState, component_name: str) -> AppState:
    """Walk ``state`` down to the sub-state addressed by a dotted component name.

    The first segment of ``component_name`` names the root, which ``state`` already is, so it is skipped. Each
    remaining segment is resolved with ``getattr`` on the previous result, e.g. ``"root.flow.work"`` yields
    ``state.flow.work``.

    """
    _root, *path = component_name.split(".")
    scoped = state
    for attribute_name in path:
        scoped = getattr(scoped, attribute_name)
    return scoped
136
137
138def _get_work_class() -> Callable:
139"""Import the work class specified in the environment."""
140work_name = os.environ["LIGHTNING_WORK"]
141work_module_file = os.environ["LIGHTNING_WORK_MODULE_FILE"]
142module = pydoc.importfile(work_module_file)
143return getattr(module, work_name)
144
145
def _build_model(work: ServeStreamlit) -> None:
    """Attach the session's model to ``work`` as ``work._model``, building it on first use.

    The model is cached in ``st.session_state`` under the ``"_model"`` key. Later runs within the same session
    reuse the cached object instead of calling ``work.build_model()`` again.

    """
    import streamlit as st

    session = st.session_state
    if "_model" not in session:
        # Built once per session (equivalent to gradio when enable_queue is False), under a loading spinner.
        with st.spinner("Building model..."):
            session["_model"] = work.build_model()

    work._model = session["_model"]
155
156
def _main() -> None:
    """Script entry point: rebuild the work from the app state, attach its model, and render it."""
    # Connect to the app state through the streamlit plugin, then narrow it to this work's component.
    app_state = AppState(plugin=StreamLitStatePlugin())
    scoped_state = _reduce_to_component_scope(app_state, os.environ["LIGHTNING_COMPONENT_NAME"])

    # Stand-in for the work instance whose ``self`` resolves through the scoped state.
    work = _PatchedWork(scoped_state, _get_work_class())

    # Attach the model before rendering so it is available to ``render``.
    _build_model(work)
    work.render()


if __name__ == "__main__":
    # ``ServeStreamlit.run`` launches ``streamlit run`` on this very file, so this branch executes in that subprocess.
    _main()
175