pytorch-lightning

Форк
0
213 строк · 7.7 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import asyncio
import base64
import json
import os
import pathlib
from dataclasses import dataclass
from enum import Enum
from time import sleep
from typing import Optional
from urllib.parse import urlencode

import click
import requests
import uvicorn
from fastapi import FastAPI, Query, Request
from starlette.background import BackgroundTask
from starlette.responses import RedirectResponse

from lightning.app.core.constants import LIGHTNING_CREDENTIAL_PATH, get_lightning_cloud_url
from lightning.app.utilities.app_helpers import Logger
from lightning.app.utilities.network import find_free_network_port
35

36
# Module-level logger (lightning's ``Logger`` wrapper) shared by the credential and login-server code in this file.
logger = Logger(__name__)
37

38

39
class Keys(Enum):
    """Environment variable names that can carry Lightning credentials.

    Each member's :attr:`suffix` (the lower-cased name without the ``LIGHTNING_`` prefix) doubles as the key used
    in the credentials JSON file and as the matching attribute name on ``Auth``.
    """

    USERNAME = "LIGHTNING_USERNAME"
    USER_ID = "LIGHTNING_USER_ID"
    API_KEY = "LIGHTNING_API_KEY"

    @property
    def suffix(self) -> str:
        """Return the member value without its leading ``LIGHTNING_`` prefix, lower-cased.

        ``str.lstrip`` removes a *set of characters* rather than a prefix, so it only produced the right answer by
        accident for the current members (``"LIGHTNING_NAME"`` would have become ``"ame"``). The exact prefix is
        sliced off instead.
        """
        prefix = "LIGHTNING_"
        value = self.value
        if value.startswith(prefix):
            value = value[len(prefix):]
        return value.lower()
47

48

49
@dataclass
class Auth:
    """Lightning credentials (username, user id, API key) and the flows used to obtain them.

    Credentials are persisted as JSON in ``secrets_file``, keyed by the :class:`Keys` suffixes
    (``username``, ``user_id``, ``api_key``).
    """

    username: Optional[str] = None
    user_id: Optional[str] = None
    api_key: Optional[str] = None

    # Unannotated on purpose so it stays a class attribute rather than a dataclass field:
    # location of the persisted credentials file.
    secrets_file = pathlib.Path(LIGHTNING_CREDENTIAL_PATH)

    def load(self) -> bool:
        """Load credentials from disk and update properties with credentials.

        Keys missing from the file are set to ``None`` on this instance.

        Returns
        ----------
        True if the credentials file exists (its values may still be empty), False otherwise.

        """
        if not self.secrets_file.exists():
            logger.debug("Credentials file not found.")
            return False
        with self.secrets_file.open() as creds:
            credentials = json.load(creds)
        for key in Keys:
            setattr(self, key.suffix, credentials.get(key.suffix, None))
        return True

    def save(self, token: str = "", user_id: str = "", api_key: str = "", username: str = "") -> None:
        """Save credentials to disk and mirror them on this instance.

        Parameters
        ----------
        token:
            Accepted for call-site compatibility only; it is not persisted.
        user_id, api_key, username:
            Values written to the credentials file and assigned to the matching attributes.

        """
        self.secrets_file.parent.mkdir(exist_ok=True, parents=True)
        with self.secrets_file.open("w") as f:
            json.dump(
                {
                    Keys.USERNAME.suffix: username,
                    Keys.USER_ID.suffix: user_id,
                    Keys.API_KEY.suffix: api_key,
                },
                f,
            )
        try:
            # The file holds the API key in plain text; restrict it to the owner where the filesystem allows it.
            self.secrets_file.chmod(0o600)
        except OSError:
            logger.debug("could not restrict permissions of the credentials file")

        self.username = username
        self.user_id = user_id
        self.api_key = api_key
        logger.debug("credentials saved successfully")

    def clear(self) -> None:
        """Remove credentials from disk (if present) and reset all credential attributes to ``None``."""
        if self.secrets_file.exists():
            self.secrets_file.unlink()
        for key in Keys:
            setattr(self, key.suffix, None)
        logger.debug("credentials removed successfully")

    @property
    def auth_header(self) -> Optional[str]:
        """Authentication header used by lightning-cloud client.

        Returns
        ----------
        A ``Basic`` header built from base64(``"{user_id}:{api_key}"``).

        Raises
        ----------
        AttributeError
            If no API key is set.

        """
        if self.api_key:
            token = f"{self.user_id}:{self.api_key}"
            return f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}"  # E501
        raise AttributeError(
            "Authentication Failed, no authentication header available. "
            "This is most likely a bug in the LightningCloud Framework"
        )

    def _run_server(self) -> None:
        """Start a local callback server and complete authentication through the browser."""
        AuthServer().login_with_browser(self)

    def authenticate(self) -> Optional[str]:
        """Performs end to end authentication flow.

        Order of precedence: the credentials file, then the ``LIGHTNING_*`` environment variables (which are
        persisted once found), then the interactive browser login.

        Returns
        ----------
        authorization header to use when authentication completes.

        Raises
        ----------
        ValueError
            If only one of the user-id / API-key environment variables is set, or if the credentials file exists
            but lacks a user id or API key.

        """
        if not self.load():
            # First try to authenticate from env
            for key in Keys:
                setattr(self, key.suffix, os.environ.get(key.value, None))

            if self.user_id and self.api_key:
                # Keep LIGHTNING_USERNAME when it was provided instead of silently replacing it with the user id.
                self.save("", self.user_id, self.api_key, self.username or self.user_id)
                logger.info("Credentials loaded from environment variables")
                return self.auth_header
            if self.api_key or self.user_id:
                raise ValueError(
                    "To use env vars for authentication both "
                    f"{Keys.USER_ID.value} and {Keys.API_KEY.value} should be set."
                )

            logger.debug("failed to load credentials, opening browser to get new.")
            self._run_server()
            return self.auth_header

        if self.user_id and self.api_key:
            return self.auth_header

        raise ValueError(
            "We couldn't find any credentials linked to your account. "
            "Please try logging in using the CLI command `lightning_app login`"
        )
149

150

151
class AuthServer:
    """Local HTTP callback server used to complete the browser-based login flow."""

    @staticmethod
    def get_auth_url(port: int) -> str:
        """Build the sign-in URL whose redirect points back at the local callback server.

        Parameters
        ----------
        port:
            Local port on which the callback server listens.

        Returns
        ----------
        ``{cloud_url}/sign-in?redirectTo=...`` with the redirect set to ``http://localhost:{port}/login-complete``.

        """
        redirect_uri = f"http://localhost:{port}/login-complete"
        params = urlencode({"redirectTo": redirect_uri})
        return f"{get_lightning_cloud_url()}/sign-in?{params}"

    def login_with_browser(self, auth: Auth, *, timeout: float = 30.0) -> None:
        """Open the sign-in page and run a local server that receives the credentials callback.

        On a successful callback the credentials are stored through ``auth.save`` and the browser is redirected to
        the cloud "login successful" page; the server is then asked to exit.

        Parameters
        ----------
        auth:
            Credential holder that receives the token, key and user id from the callback.
        timeout:
            Seconds to wait for the sign-in page reachability probe before giving up.

        Raises
        ----------
        requests.ConnectionError
            If the sign-in page cannot be reached.
        requests.RequestException
            For any other error while probing the sign-in page.

        """
        app = FastAPI()
        port = find_free_network_port()
        url = self.get_auth_url(port)

        try:
            # check if server is reachable or catch any network errors; bounded so a dead network cannot hang login
            requests.head(url, timeout=timeout)
        except requests.ConnectionError as ex:
            raise requests.ConnectionError(
                f"No internet connection available. Please connect to a stable internet connection \n{ex}"  # E501
            ) from ex
        except requests.RequestException as ex:
            raise requests.RequestException(
                f"An error occurred with the request. Please report this issue to Lightning Team \n{ex}"  # E501
            ) from ex

        logger.info(
            "\nAttempting to automatically open the login page in your default browser.\n"
            'If the browser does not open, navigate to the "Keys" tab on your Lightning AI profile page:\n\n'
            f"{get_lightning_cloud_url()}/me/keys\n\n"
            'Copy the "Headless CLI Login" command, and execute it in your terminal.\n'
        )
        click.launch(url)

        @app.get("/login-complete")
        async def save_token(request: Request, token="", key="", user_id: str = Query("", alias="userID")):
            async def stop_server_once_request_is_done():
                # Poll until the client has disconnected, then ask uvicorn to shut down.
                while not await request.is_disconnected():
                    # Non-blocking wait: a time.sleep here would freeze the server's event loop.
                    await asyncio.sleep(0.25)
                server.should_exit = True

            if not token:
                logger.warn(
                    "Login Failed. This is most likely because you're using an older version of the CLI. \n"  # E501
                    "Please try to update the CLI or open an issue with this information \n"  # E501
                    f"expected token in {request.query_params.items()}"
                )
                return RedirectResponse(
                    url=f"{get_lightning_cloud_url()}/cli-login-failed",
                    background=BackgroundTask(stop_server_once_request_is_done),
                )

            auth.save(token=token, username=user_id, user_id=user_id, api_key=key)
            logger.info("Login Successful")

            # Include the credentials in the redirect so that UI will also be logged in
            params = urlencode({"token": token, "key": key, "userID": user_id})

            return RedirectResponse(
                url=f"{get_lightning_cloud_url()}/cli-login-successful?{params}",
                background=BackgroundTask(stop_server_once_request_is_done),
            )

        # `server` is resolved lazily by the callback's closure, so defining it after the route is fine.
        server = uvicorn.Server(config=uvicorn.Config(app, port=port, log_level="error"))
        server.run()
214

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.