pytorch-lightning
213 строк · 7.7 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import base64
16import json
17import os
18import pathlib
19from dataclasses import dataclass
20from enum import Enum
21from time import sleep
22from typing import Optional
23from urllib.parse import urlencode
24
25import click
26import requests
27import uvicorn
28from fastapi import FastAPI, Query, Request
29from starlette.background import BackgroundTask
30from starlette.responses import RedirectResponse
31
32from lightning.app.core.constants import LIGHTNING_CREDENTIAL_PATH, get_lightning_cloud_url
33from lightning.app.utilities.app_helpers import Logger
34from lightning.app.utilities.network import find_free_network_port
35
# Module-level logger for the credential and browser-login helpers in this file
# (``Logger`` is imported from ``lightning.app.utilities.app_helpers``).
logger = Logger(__name__)
37
38
class Keys(Enum):
    """Names of the Lightning credentials.

    Each member's value is the environment variable read by ``Auth.authenticate``;
    ``suffix`` gives the matching JSON key / ``Auth`` attribute name.
    """

    USERNAME = "LIGHTNING_USERNAME"
    USER_ID = "LIGHTNING_USER_ID"
    API_KEY = "LIGHTNING_API_KEY"

    @property
    def suffix(self):
        """Return the value without its leading ``LIGHTNING_`` prefix, lower-cased.

        Example: ``Keys.USER_ID.suffix == "user_id"``.

        ``str.lstrip("LIGHTNING_")`` must not be used here: it strips any leading
        characters from the set {L, I, G, H, T, N, _}, not the prefix string. That
        would also eat the start of names such as ``LIGHTNING_TOKEN`` (-> ``"oken"``).
        """
        # Kept as a local: a class-level constant inside an Enum would become a member.
        prefix = "LIGHTNING_"
        name = self.value
        if name.startswith(prefix):
            name = name[len(prefix):]
        return name.lower()
47
48
@dataclass
class Auth:
    """Holds Lightning credentials and implements the end-to-end login flow.

    Credentials are persisted as a JSON object in ``secrets_file``. Its keys are the
    ``Keys.*.suffix`` names (``username``, ``user_id``, ``api_key``), which are also
    the attribute names on this dataclass.
    """

    username: Optional[str] = None
    user_id: Optional[str] = None
    api_key: Optional[str] = None

    # Plain class attribute (no annotation, so not a dataclass field): location of
    # the credentials JSON file.
    secrets_file = pathlib.Path(LIGHTNING_CREDENTIAL_PATH)

    def load(self) -> bool:
        """Load credentials from disk and update properties with credentials.

        Every field named by ``Keys`` is overwritten with the file's value, or
        ``None`` when the key is absent.

        Returns
        ----------
        True if the credentials file exists and was read. Individual fields may
        still be ``None`` if the file did not contain them.

        """
        if not self.secrets_file.exists():
            logger.debug("Credentials file not found.")
            return False
        with self.secrets_file.open(encoding="utf-8") as creds:
            credentials = json.load(creds)
        for key in Keys:
            setattr(self, key.suffix, credentials.get(key.suffix, None))
        return True

    def save(self, token: str = "", user_id: str = "", api_key: str = "", username: str = "") -> None:
        """Save credentials to disk and mirror them on this instance.

        Parameters
        ----------
        token:
            Accepted for call compatibility; it is neither written nor stored.
        user_id:
            Lightning user id.
        api_key:
            Lightning API key.
        username:
            Lightning username.

        """
        self.secrets_file.parent.mkdir(exist_ok=True, parents=True)
        # The file contains an API key, so create it readable/writable by the owner
        # only (0o600, further restricted by the umask). The mode applies only when
        # the file is created; an existing file keeps its current permissions.
        fd = os.open(self.secrets_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(
                {
                    f"{Keys.USERNAME.suffix}": username,
                    f"{Keys.USER_ID.suffix}": user_id,
                    f"{Keys.API_KEY.suffix}": api_key,
                },
                f,
            )

        self.username = username
        self.user_id = user_id
        self.api_key = api_key
        logger.debug("credentials saved successfully")

    def clear(self) -> None:
        """Remove credentials from disk (if present) and reset all fields to ``None``."""
        try:
            self.secrets_file.unlink()
        except FileNotFoundError:
            # Nothing on disk, possibly removed concurrently. The original
            # exists()/unlink() pair raced on exactly this case.
            pass
        for key in Keys:
            setattr(self, key.suffix, None)
        logger.debug("credentials removed successfully")

    @property
    def auth_header(self) -> Optional[str]:
        """Authentication header used by lightning-cloud client.

        Returns ``"Basic " + base64("<user_id>:<api_key>")``.

        Raises
        ------
        AttributeError
            If ``api_key`` is empty. Because this is a property, ``hasattr(auth,
            "auth_header")`` evaluates to False in that case.

        """
        if self.api_key:
            token = f"{self.user_id}:{self.api_key}"
            return f"Basic {base64.b64encode(token.encode('ascii')).decode('ascii')}"  # E501
        raise AttributeError(
            "Authentication Failed, no authentication header available. "
            "This is most likely a bug in the LightningCloud Framework"
        )

    def _run_server(self) -> None:
        """Start a server to complete authentication."""
        AuthServer().login_with_browser(self)

    def authenticate(self) -> Optional[str]:
        """Performs end to end authentication flow.

        Credentials are taken from the first available source:

        1. the credentials file (see ``load``);
        2. the ``LIGHTNING_USER_ID`` / ``LIGHTNING_API_KEY`` environment variables,
           which are then saved to disk;
        3. the interactive browser login (``_run_server``).

        Returns
        ----------
        authorization header to use when authentication completes.

        Raises
        ------
        ValueError
            If only one of the two environment variables is set, or the credentials
            file lacks ``user_id``/``api_key``.

        """
        if not self.load():
            # First try to authenticate from env
            for key in Keys:
                setattr(self, key.suffix, os.environ.get(key.value, None))

            if self.user_id and self.api_key:
                # The user id is also persisted as the username; this matches the
                # browser-login path, which saves ``username=user_id`` too.
                self.save(token="", user_id=self.user_id, api_key=self.api_key, username=self.user_id)
                logger.info("Credentials loaded from environment variables")
                return self.auth_header
            if self.api_key or self.user_id:
                raise ValueError(
                    "To use env vars for authentication both "
                    f"{Keys.USER_ID.value} and {Keys.API_KEY.value} should be set."
                )

            logger.debug("failed to load credentials, opening browser to get new.")
            self._run_server()
            return self.auth_header

        if self.user_id and self.api_key:
            return self.auth_header

        raise ValueError(
            "We couldn't find any credentials linked to your account. "
            "Please try logging in using the CLI command `lightning_app login`"
        )
149
150
class AuthServer:
    """One-shot local HTTP server that receives credentials after browser sign-in."""

    # Seconds to wait for the sign-in page to answer the reachability probe. Without
    # a timeout, ``requests`` can block forever on an unresponsive network.
    _REACHABILITY_TIMEOUT = 30

    @staticmethod
    def get_auth_url(port: int) -> str:
        """Return the sign-in URL that redirects back to ``http://localhost:<port>/login-complete``."""
        redirect_uri = f"http://localhost:{port}/login-complete"
        params = urlencode({"redirectTo": redirect_uri})
        return f"{get_lightning_cloud_url()}/sign-in?{params}"

    def login_with_browser(self, auth: Auth) -> None:
        """Run the interactive browser login and store the resulting credentials in ``auth``.

        Opens the sign-in page, then serves ``/login-complete`` on a free local port
        until that redirect has been handled and its client has disconnected.

        Raises
        ------
        requests.ConnectionError
            If the sign-in page cannot be reached.
        requests.RequestException
            For any other error (including a timeout) while probing the sign-in page.

        """
        # Local import: only needed for the non-blocking wait in the background task.
        import asyncio

        app = FastAPI()
        port = find_free_network_port()
        url = self.get_auth_url(port)

        try:
            # check if server is reachable or catch any network errors
            requests.head(url, timeout=self._REACHABILITY_TIMEOUT)
        except requests.ConnectionError as ex:
            raise requests.ConnectionError(
                f"No internet connection available. Please connect to a stable internet connection \n{ex}"  # E501
            ) from ex
        except requests.RequestException as ex:
            raise requests.RequestException(
                f"An error occurred with the request. Please report this issue to Lightning Team \n{ex}"  # E501
            ) from ex

        logger.info(
            "\nAttempting to automatically open the login page in your default browser.\n"
            'If the browser does not open, navigate to the "Keys" tab on your Lightning AI profile page:\n\n'
            f"{get_lightning_cloud_url()}/me/keys\n\n"
            'Copy the "Headless CLI Login" command, and execute it in your terminal.\n'
        )
        click.launch(url)

        @app.get("/login-complete")
        async def save_token(request: Request, token="", key="", user_id: str = Query("", alias="userID")):
            async def stop_server_once_request_is_done():
                # Must be an awaitable sleep: a blocking ``time.sleep`` here would stall
                # the event loop that uvicorn is serving on.
                while not await request.is_disconnected():
                    await asyncio.sleep(0.25)
                # ``server`` is bound later in the enclosing function. The closure
                # resolves it at call time, after the server is running.
                server.should_exit = True

            if not token:
                logger.warn(
                    "Login Failed. This is most likely because you're using an older version of the CLI. \n"  # E501
                    "Please try to update the CLI or open an issue with this information \n"  # E501
                    f"expected token in {request.query_params.items()}"
                )
                return RedirectResponse(
                    url=f"{get_lightning_cloud_url()}/cli-login-failed",
                    background=BackgroundTask(stop_server_once_request_is_done),
                )

            auth.save(token=token, username=user_id, user_id=user_id, api_key=key)
            logger.info("Login Successful")

            # Include the credentials in the redirect so that UI will also be logged in
            params = urlencode({"token": token, "key": key, "userID": user_id})

            return RedirectResponse(
                url=f"{get_lightning_cloud_url()}/cli-login-successful?{params}",
                background=BackgroundTask(stop_server_once_request_is_done),
            )

        server = uvicorn.Server(config=uvicorn.Config(app, port=port, log_level="error"))
        server.run()
214