pytorch-lightning

Форк
0
200 строк · 6.9 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import abc
16
from functools import partial
17
from types import ModuleType
18
from typing import Any, List, Optional
19

20
from lightning.app.core.work import LightningWork
21
from lightning.app.utilities.imports import _is_gradio_available, requires
22

23
if not _is_gradio_available():
    # gradio is not installed: build a minimal stand-in module so that the
    # ``gradio.themes.Base`` annotation used further down still resolves at import time.
    class __DummyBase:
        pass

    gradio = ModuleType("gradio")
    gradio.themes = ModuleType("gradio.themes")
    gradio.themes.Base = __DummyBase
else:
    import gradio
33

34

35
class ServeGradio(LightningWork, abc.ABC):
    """The ServeGradio class lets you quickly create a ``gradio``-based UI for your LightningApp.

    In the example below, ``ServeGradio`` is subclassed to deploy ``AnimeGANv2``.

    .. literalinclude:: ../../../../examples/app/components/serve/gradio/app.py
        :language: python

    The result would be the following:

    .. image:: https://pl-public-data.s3.amazonaws.com/assets_lightning/anime_gan.gif
        :alt: Animation showing what the AnimeGANv2 UI would look like.

    Subclasses provide ``inputs`` and ``outputs`` and implement :meth:`build_model` and :meth:`predict`.

    """

    # Passed to ``gradio.Interface`` as ``inputs`` / ``outputs``; both must be truthy (checked in ``__init__``).
    inputs: Any
    outputs: Any
    # Passed to ``gradio.Interface`` as ``examples``.
    examples: Optional[List] = None
    # Passed to ``Interface.launch`` as ``enable_queue``.
    enable_queue: bool = False
    # Passed to ``gradio.Interface`` as ``title`` / ``description``.
    title: Optional[str] = None
    description: Optional[str] = None

    _start_method = "spawn"

    def __init__(self, *args: Any, theme: Optional[gradio.themes.Base] = None, **kwargs: Any):
        """Initialize the work and select the Gradio theme.

        Args:
            *args: Positional arguments forwarded to the parent initializer.
            theme: Theme handed to ``gradio.Interface``. When omitted (or falsy), the Lightning theme
                built by ``__get_lightning_gradio_theme`` is used.
            **kwargs: Keyword arguments forwarded to the parent initializer.

        Raises:
            AssertionError: If ``inputs`` or ``outputs`` is falsy.

        """
        # Wrap the parent initializer itself and then call the wrapped callable. The previous form,
        # ``requires("gradio")(super().__init__(*args, **kwargs))``, ran the initializer first and passed
        # its return value (``None``) to ``requires("gradio")``, discarding whatever came back -- so any
        # check performed at call time never executed.
        # NOTE(review): assumes ``requires`` is a decorator factory, as its call shape suggests.
        requires("gradio")(super().__init__)(*args, **kwargs)
        assert self.inputs, "ServeGradio requires a non-empty `inputs` attribute."
        assert self.outputs, "ServeGradio requires a non-empty `outputs` attribute."
        self._model = None
        self._theme = theme or ServeGradio.__get_lightning_gradio_theme()

        self.ready = False

    @property
    def model(self) -> Any:
        """The object returned by :meth:`build_model`, or ``None`` before :meth:`run` has built it."""
        return self._model

    @abc.abstractmethod
    def predict(self, *args: Any, **kwargs: Any) -> Any:
        """Override with your logic to make a prediction."""

    @abc.abstractmethod
    def build_model(self) -> Any:
        """Override to instantiate and return your model.

        The model will be accessible under ``self.model``.

        """

    def run(self, *args: Any, **kwargs: Any) -> None:
        """Build the model if needed and launch the Gradio interface.

        Args:
            *args: Positional arguments pre-bound to :meth:`predict`.
            **kwargs: Keyword arguments pre-bound to :meth:`predict`.

        """
        # Build the model only once; later calls reuse the cached instance.
        if self._model is None:
            self._model = self.build_model()
        fn = partial(self.predict, *args, **kwargs)
        # ``partial`` objects have no ``__name__`` of their own; copy the one from ``predict``.
        fn.__name__ = self.predict.__name__
        # The flag is raised before ``launch`` is called.
        self.ready = True
        gradio.Interface(
            fn=fn,
            inputs=self.inputs,
            outputs=self.outputs,
            examples=self.examples,
            title=self.title,
            description=self.description,
            theme=self._theme,
        ).launch(
            server_name=self.host,
            server_port=self.port,
            enable_queue=self.enable_queue,
        )

    def configure_layout(self) -> str:
        """Return ``self.url`` as the layout of this work."""
        return self.url

    @staticmethod
    def __get_lightning_gradio_theme():
        """Build the default Lightning-styled Gradio theme (purple palette on top of ``gradio.themes.Default``)."""
        # NOTE(review): the positional colors are presumably the palette shades from lightest (50)
        # to darkest (950) -- confirm against ``gradio.themes.Color``.
        return gradio.themes.Default(
            primary_hue=gradio.themes.Color(
                "#ffffff",
                "#e9d5ff",
                "#d8b4fe",
                "#c084fc",
                "#fcfcfc",
                "#a855f7",
                "#9333ea",
                "#8823e1",
                "#6b21a8",
                "#2c2730",
                "#1c1c1c",
            ),
            secondary_hue=gradio.themes.Color(
                "#c3a1e8",
                "#e9d5ff",
                "#d3bbec",
                "#c795f9",
                "#9174af",
                "#a855f7",
                "#9333ea",
                "#6700c2",
                "#000000",
                "#991ef1",
                "#33243d",
            ),
            neutral_hue=gradio.themes.Color(
                "#ede9fe",
                "#ddd6fe",
                "#c4b5fd",
                "#a78bfa",
                "#fafafa",
                "#8b5cf6",
                "#7c3aed",
                "#6d28d9",
                "#6130b0",
                "#8a4ce6",
                "#3b3348",
            ),
        ).set(
            body_background_fill="*primary_50",
            body_background_fill_dark="*primary_950",
            body_text_color_dark="*primary_100",
            body_text_size="*text_sm",
            body_text_color_subdued_dark="*primary_100",
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_950",
            background_fill_secondary="*primary_50",
            background_fill_secondary_dark="*primary_950",
            border_color_accent="*primary_400",
            border_color_accent_dark="*primary_900",
            border_color_primary="*primary_600",
            border_color_primary_dark="*primary_800",
            color_accent="*primary_400",
            color_accent_soft="*primary_300",
            color_accent_soft_dark="*primary_700",
            link_text_color="*primary_500",
            link_text_color_dark="*primary_50",
            link_text_color_active="*secondary_800",
            link_text_color_active_dark="*primary_500",
            link_text_color_hover="*primary_400",
            link_text_color_hover_dark="*primary_400",
            link_text_color_visited="*primary_500",
            link_text_color_visited_dark="*secondary_100",
            block_background_fill="*primary_50",
            block_background_fill_dark="*primary_900",
            block_border_color_dark="*primary_800",
            checkbox_background_color="*primary_50",
            checkbox_background_color_dark="*primary_50",
            checkbox_background_color_focus="*primary_100",
            checkbox_background_color_focus_dark="*primary_100",
            checkbox_background_color_hover="*primary_400",
            checkbox_background_color_hover_dark="*primary_500",
            checkbox_background_color_selected="*primary_300",
            checkbox_background_color_selected_dark="*primary_500",
            checkbox_border_color_dark="*primary_200",
            checkbox_border_radius="*radius_md",
            input_background_fill="*primary_50",
            input_background_fill_dark="*primary_900",
            input_radius="*radius_xxl",
            slider_color="*primary_600",
            slider_color_dark="*primary_700",
            button_large_radius="*radius_xxl",
            button_large_text_size="*text_md",
            button_small_radius="*radius_xxl",
            button_primary_background_fill_dark="*primary_800",
            button_primary_background_fill_hover_dark="*primary_700",
            button_primary_border_color_dark="*primary_800",
            button_secondary_background_fill="*neutral_200",
            button_secondary_background_fill_dark="*primary_600",
        )
201

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.