pytorch-lightning
148 строк · 3.9 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""General utilities."""
15
16import functools
17import os
18import platform
19import sys
20import warnings
21from typing import Any, List, Union
22
23from lightning_utilities.core.imports import module_available
24from packaging.requirements import Marker, Requirement
25
26try:
27from importlib import metadata
28except ImportError:
29# Python < 3.8
30import importlib_metadata as metadata # type: ignore
31
32
def _get_extras(extras: str) -> str:
    """Get the given extras as a space delimited string.

    Used by the platform to install cloud extras in the cloud.

    Args:
        extras: Name of an extra declared by the ``lightning.app`` distribution. It is matched against the
            marker ``extra == "<extras>"``.

    Returns:
        The requirement specifiers of every dependency whose environment marker is exactly
        ``extra == "<extras>"``. Each specifier has its marker removed, is wrapped in single quotes, and the
        results are joined by single spaces. Returns an empty string when nothing matches, or when the
        distribution declares no requirements at all.

    """
    from lightning.app import __package_name__

    # ``metadata.requires`` returns ``None`` (not an empty list) for a distribution without requirements.
    raw_requirements = metadata.requires(__package_name__) or []
    target_marker = str(Marker(f'extra == "{extras}"'))

    selected = []
    # ``dict.fromkeys`` drops duplicate requirement strings while keeping their declaration order.
    for raw in dict.fromkeys(raw_requirements):
        # Only an exact marker match counts. Compound markers such as
        # ``extra == "x" and python_version < "3.9"`` are not selected.
        if str(Requirement(raw).marker) == target_marker:
            # Strip the ``; <marker>`` suffix and quote the spec (presumably so specifiers such as ``>=``
            # survive being pasted into a shell command line).
            selected.append(f"'{raw.split(';')[0].strip()}'")
    return " ".join(selected)
49
50
def requires(module_paths: Union[str, List]):
    """Build a decorator that checks module availability each time the decorated function is called.

    Decorating never fails; the check runs only when the wrapped function is invoked.

    Args:
        module_paths: A single dotted module path, or a list/tuple of them, e.g. ``"redis"`` or
            ``["torch", "torchvision"]``.

    Returns:
        A decorator that wraps the target function with the availability check.

    Raises:
        ModuleNotFoundError: When the wrapped function is called and ``module_available`` reports at least one
            of ``module_paths`` as missing. If the ``LIGHTING_TESTING`` environment variable parses as a
            non-zero integer, a warning is emitted instead and the function still runs.

    """
    # A bare string is one module path. A list or tuple is a collection of paths; copy it so that later
    # mutation of the caller's object cannot change what gets checked. Anything else is wrapped as-is.
    if isinstance(module_paths, (list, tuple)):
        module_paths = list(module_paths)
    else:
        module_paths = [module_paths]

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any):
            unavailable_modules = [f"'{module}'" for module in module_paths if not module_available(module)]
            if unavailable_modules:
                # NOTE: the variable is spelled "LIGHTING_TESTING" (looks like a typo of LIGHTNING); it is kept
                # verbatim so environments that already set it keep working.
                is_lit_testing = bool(int(os.getenv("LIGHTING_TESTING", "0")))
                msg = f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
                if is_lit_testing:
                    # stacklevel=2 attributes the warning to the caller of the decorated function.
                    warnings.warn(msg, stacklevel=2)
                else:
                    raise ModuleNotFoundError(msg)
            return func(*args, **kwargs)

        return wrapper

    return decorator
71
72
# TODO: Automatically detect dependencies
def _is_redis_available() -> bool:
    """Return whether ``module_available`` reports the ``redis`` dependency as available."""
    package = "redis"
    return module_available(package)
76
77
def _is_torch_available() -> bool:
    """Return whether ``module_available`` reports the ``torch`` dependency as available."""
    package = "torch"
    return module_available(package)
80
81
def _is_pytorch_lightning_available() -> bool:
    """Return whether ``module_available`` reports the ``lightning.pytorch`` module as available."""
    package = "lightning.pytorch"
    return module_available(package)
84
85
def _is_torchvision_available() -> bool:
    """Return whether ``module_available`` reports the ``torchvision`` dependency as available."""
    package = "torchvision"
    return module_available(package)
88
89
def _is_json_argparse_available() -> bool:
    """Return whether ``module_available`` reports the ``jsonargparse`` dependency as available."""
    package = "jsonargparse"
    return module_available(package)
92
93
def _is_streamlit_available() -> bool:
    """Return whether ``module_available`` reports the ``streamlit`` dependency as available."""
    package = "streamlit"
    return module_available(package)
96
97
def _is_param_available() -> bool:
    """Return whether ``module_available`` reports the ``param`` dependency as available."""
    package = "param"
    return module_available(package)
100
101
def _is_streamlit_tensorboard_available() -> bool:
    """Return whether ``module_available`` reports the ``streamlit_tensorboard`` dependency as available."""
    package = "streamlit_tensorboard"
    return module_available(package)
104
105
def _is_gradio_available() -> bool:
    """Return whether ``module_available`` reports the ``gradio`` dependency as available."""
    package = "gradio"
    return module_available(package)
108
109
def _is_lightning_flash_available() -> bool:
    """Return whether ``module_available`` reports the ``flash`` module as available."""
    package = "flash"
    return module_available(package)
112
113
def _is_pil_available() -> bool:
    """Return whether ``module_available`` reports the ``PIL`` module as available."""
    package = "PIL"
    return module_available(package)
116
117
def _is_numpy_available() -> bool:
    """Return whether ``module_available`` reports the ``numpy`` dependency as available."""
    package = "numpy"
    return module_available(package)
120
121
def _is_docker_available() -> bool:
    """Return whether ``module_available`` reports the ``docker`` Python module as available."""
    package = "docker"
    return module_available(package)
124
125
def _is_jinja2_available() -> bool:
    """Return whether ``module_available`` reports the ``jinja2`` dependency as available."""
    package = "jinja2"
    return module_available(package)
128
129
def _is_playwright_available() -> bool:
    """Return whether ``module_available`` reports the ``playwright`` dependency as available."""
    package = "playwright"
    return module_available(package)
132
133
def _is_s3fs_available() -> bool:
    """Return whether ``module_available`` reports the ``s3fs`` dependency as available."""
    package = "s3fs"
    return module_available(package)
136
137
def _is_sqlmodel_available() -> bool:
    """Return whether ``module_available`` reports the ``sqlmodel`` dependency as available."""
    package = "sqlmodel"
    return module_available(package)
140
141
def _is_aiohttp_available() -> bool:
    """Return whether ``module_available`` reports the ``aiohttp`` dependency as available."""
    package = "aiohttp"
    return module_available(package)
144
145
def _env_flag(name: str) -> bool:
    """Interpret the environment variable ``name`` as a boolean flag.

    An unset, empty or whitespace-only value is ``False``. So are the spellings ``0``, ``false``, ``no`` and
    ``off`` (case-insensitive). Any other value is ``True``.

    """
    return os.getenv(name, "").strip().lower() not in ("", "0", "false", "no", "off")


# Whether cloud tests are enabled through the ``CLOUD`` environment variable. ``CLOUD=0`` or ``CLOUD=false``
# disable it; a plain ``bool(os.getenv(...))`` would treat any non-empty string as enabled.
_CLOUD_TEST_RUN = _env_flag("CLOUD")
_IS_WINDOWS = platform.system() == "Windows"
_IS_MACOS = sys.platform == "darwin"
149