pytorch-lightning

Форк
0
148 строк · 3.9 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""General utilities."""
15

16
import functools
17
import os
18
import platform
19
import sys
20
import warnings
21
from typing import Any, List, Union
22

23
from lightning_utilities.core.imports import module_available
24
from packaging.requirements import Marker, Requirement
25

26
try:
27
    from importlib import metadata
28
except ImportError:
29
    # Python < 3.8
30
    import importlib_metadata as metadata  # type: ignore
31

32

33
def _get_extras(extras: str) -> str:
34
    """Get the given extras as a space delimited string.
35

36
    Used by the platform to install cloud extras in the cloud.
37

38
    """
39
    from lightning.app import __package_name__
40

41
    requirements = {r: Requirement(r) for r in metadata.requires(__package_name__)}
42
    marker = Marker(f'extra == "{extras}"')
43
    requirements = [r for r, req in requirements.items() if str(req.marker) == str(marker)]
44

45
    if requirements:
46
        requirements = [f"'{r.split(';')[0].strip()}'" for r in requirements]
47
        return " ".join(requirements)
48
    return ""
49

50

51
def requires(module_paths: Union[str, List]):
52
    if not isinstance(module_paths, list):
53
        module_paths = [module_paths]
54

55
    def decorator(func):
56
        @functools.wraps(func)
57
        def wrapper(*args: Any, **kwargs: Any):
58
            unavailable_modules = [f"'{module}'" for module in module_paths if not module_available(module)]
59
            if any(unavailable_modules):
60
                is_lit_testing = bool(int(os.getenv("LIGHTING_TESTING", "0")))
61
                msg = f"Required dependencies not available. Please run: pip install {' '.join(unavailable_modules)}"
62
                if is_lit_testing:
63
                    warnings.warn(msg)
64
                else:
65
                    raise ModuleNotFoundError(msg)
66
            return func(*args, **kwargs)
67

68
        return wrapper
69

70
    return decorator
71

72

73
# TODO: Automatically detect dependencies
74
def _is_redis_available() -> bool:
75
    return module_available("redis")
76

77

78
def _is_torch_available() -> bool:
79
    return module_available("torch")
80

81

82
def _is_pytorch_lightning_available() -> bool:
83
    return module_available("lightning.pytorch")
84

85

86
def _is_torchvision_available() -> bool:
87
    return module_available("torchvision")
88

89

90
def _is_json_argparse_available() -> bool:
91
    return module_available("jsonargparse")
92

93

94
def _is_streamlit_available() -> bool:
95
    return module_available("streamlit")
96

97

98
def _is_param_available() -> bool:
99
    return module_available("param")
100

101

102
def _is_streamlit_tensorboard_available() -> bool:
103
    return module_available("streamlit_tensorboard")
104

105

106
def _is_gradio_available() -> bool:
107
    return module_available("gradio")
108

109

110
def _is_lightning_flash_available() -> bool:
111
    return module_available("flash")
112

113

114
def _is_pil_available() -> bool:
115
    return module_available("PIL")
116

117

118
def _is_numpy_available() -> bool:
119
    return module_available("numpy")
120

121

122
def _is_docker_available() -> bool:
123
    return module_available("docker")
124

125

126
def _is_jinja2_available() -> bool:
127
    return module_available("jinja2")
128

129

130
def _is_playwright_available() -> bool:
131
    return module_available("playwright")
132

133

134
def _is_s3fs_available() -> bool:
135
    return module_available("s3fs")
136

137

138
def _is_sqlmodel_available() -> bool:
139
    return module_available("sqlmodel")
140

141

142
def _is_aiohttp_available() -> bool:
143
    return module_available("aiohttp")
144

145

146
_CLOUD_TEST_RUN = bool(os.getenv("CLOUD", False))
147
_IS_WINDOWS = platform.system() == "Windows"
148
_IS_MACOS = sys.platform == "darwin"
149

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.