pytorch-lightning
1from enum import Enum
2
3from lightning.app.core.constants import APP_SERVER_IN_CLOUD
4from lightning.app.runners.backends.backend import Backend
5from lightning.app.runners.backends.cloud import CloudBackend
6from lightning.app.runners.backends.docker import DockerBackend
7from lightning.app.runners.backends.mp_process import CloudMultiProcessingBackend, MultiProcessingBackend
8
9
class BackendType(Enum):
    """Closed set of runtime backends an app can be launched on.

    Each member's value is the lowercase string identifier of the backend,
    so a member can be looked up from that string, e.g. ``BackendType("docker")``.
    """

    MULTIPROCESSING = "multiprocessing"
    DOCKER = "docker"
    CLOUD = "cloud"

    def get_backend(self, entrypoint_file: str) -> "Backend":
        """Instantiate the backend implementation matching this member.

        Args:
            entrypoint_file: Passed unchanged as the sole argument to the
                selected backend's constructor.

        Returns:
            A ``CloudMultiProcessingBackend`` or ``MultiProcessingBackend`` for
            ``MULTIPROCESSING`` (depending on ``APP_SERVER_IN_CLOUD``), a
            ``DockerBackend`` for ``DOCKER``, or a ``CloudBackend`` for ``CLOUD``.

        Raises:
            ValueError: If no branch matches this member. This cannot happen
                for the three members declared above; it guards against a
                member being added without a corresponding branch here.
        """
        # Enum members are singletons, so identity is the idiomatic comparison.
        if self is BackendType.MULTIPROCESSING:
            # The multiprocessing flavour is chosen by the APP_SERVER_IN_CLOUD
            # constant, read at call time.
            if APP_SERVER_IN_CLOUD:
                return CloudMultiProcessingBackend(entrypoint_file)
            return MultiProcessingBackend(entrypoint_file)
        if self is BackendType.DOCKER:
            return DockerBackend(entrypoint_file)
        if self is BackendType.CLOUD:
            return CloudBackend(entrypoint_file)
        raise ValueError("Unknown client type")
25