1
# Copyright The Lightning AI team.
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
7
# http://www.apache.org/licenses/LICENSE-2.0
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
19
from datetime import datetime
20
from typing import Optional
22
from lightning.app import _PROJECT_ROOT, LightningWork
23
from lightning.app.storage.path import _shared_local_mount_path
24
from lightning.app.utilities.imports import _is_docker_available, _is_jinja2_available, requires
26
if _is_docker_available():
28
from docker.models.containers import Container
30
if _is_jinja2_available():
36
def __init__(self, file: str, work: LightningWork, queue_id: str, create_base: bool = False):
# Set up the runner state and call the container-building helpers.
# NOTE(review): the visible lines of this method have gaps (the stray line
# numbers jump from 36 to 39 and from 40 to 42). `file`, `work` and
# `create_base` are not used in the lines shown here, although other methods
# read `self.file` and `self.work` — confirm against the complete file.
39
# Forwarded as ``--queue_id`` on the command line assembled in ``run``.
self.queue_id = queue_id
40
# Docker image tag; stays ``None`` until ``_create_work_container`` sets it.
self.image: Optional[str] = None
42
# Build the base image from ``dockers/Dockerfile.base.cpu``.
self._create_base_container()
43
# Build the work-specific image from the Jinja Dockerfile template.
self._create_work_container()
45
def _create_base_container(self) -> None:
    """Build the base Docker image tagged ``thomaschaton/base``.

    The CPU base Dockerfile is copied to ``<_PROJECT_ROOT>/Dockerfile``,
    ``docker build`` is invoked, and the copied file is removed afterwards.
    The removal runs in a ``finally`` clause so the temporary Dockerfile does
    not linger in the project root when the build step is interrupted.
    """
    # 1. Get base container
    container_base = f"{_PROJECT_ROOT}/dockers/Dockerfile.base.cpu"
    destination_path = os.path.join(_PROJECT_ROOT, "Dockerfile")

    # 2. Copy the base Dockerfile within the Lightning project
    shutil.copy(container_base, destination_path)

    try:
        # 3. Build the docker image.
        # NOTE(review): the build context is ".", i.e. the current working
        # directory, while the Dockerfile was copied into _PROJECT_ROOT —
        # presumably this is expected to run from the project root; confirm.
        # The exit status of the build command is not inspected.
        os.system("docker build . --tag thomaschaton/base")
    finally:
        # 4. Clean the copied base Dockerfile, even if the build step raised.
        os.remove(destination_path)
59
def _create_work_container(self) -> None:
# Build the Docker image for this work: pickle the work, render the Jinja
# Dockerfile template, run ``docker build`` and delete the temporary files.
# NOTE(review): the visible lines have a gap after the ``redis_host`` argument
# (the stray line numbers jump from 79 to 82), so the closing of the
# ``template.render(`` call is not shown here — confirm against the full file.
60
# 1. Get work container.
61
source_path = os.path.join(_PROJECT_ROOT, "dockers/Dockerfile.jinja")
62
# Rendered Dockerfile and pickled work are both written into _PROJECT_ROOT.
destination_path = os.path.join(_PROJECT_ROOT, "Dockerfile")
63
work_destination_path = os.path.join(_PROJECT_ROOT, "work.p")
66
# Serialize the work object to ``work.p``.
with open(work_destination_path, "wb") as f:
67
pickle.dump(self.work, f)
69
# 3. Load Lightning requirements.
70
with open(source_path) as f:
71
template = jinja2.Template(f.read())
73
# Get the work local build spec.
74
requirements = self.work.local_build_config.requirements
76
# Render template with the requirements.
77
dockerfile_str = template.render(
78
# Requirements are passed to the template as one space-separated string.
requirements=" ".join(requirements),
79
# macOS uses the ``host.docker.internal`` host name; other platforms use loopback.
redis_host="host.docker.internal" if sys.platform == "darwin" else "127.0.0.1",
82
with open(destination_path, "w") as f:
83
f.write(dockerfile_str)
85
# 4. Build the container.
86
# The image tag is derived from the lower-cased qualified class name of the work.
self.image = f"work-{self.work.__class__.__qualname__.lower()}"
87
# NOTE(review): the build context is the current working directory and the
# exit status of the command is not checked — confirm this is intended.
os.system(f"docker build . --tag {self.image}")
89
# 5. Clean the leftover files.
90
os.remove(destination_path)
91
os.remove(work_destination_path)
93
def run(self) -> None:
# Start the work container and read its log stream, raising RuntimeError
# when a log line contains "command not found".
# NOTE(review): several lines of this method are missing from this view (the
# stray line numbers jump 93 -> 96, 98 -> 105, 108 -> 112 and stop at 118), so
# part of the ``containers.run`` call and the body of the final ``if`` are not
# shown — confirm against the complete file.
96
# 1. Run the work container and launch the work.
97
# Talk to the Docker daemon over the given unix socket URL.
client = docker.DockerClient(base_url="unix://var/run/docker.sock")
98
self.container: Container = client.containers.run(
105
# Take the text after the last ":" of every non-empty work URL as a port.
# NOTE(review): the Docker SDK documents ``ports`` as a dict — confirm that a
# list is accepted here.
ports=[url.split(":")[-1] for url in self.work._urls if url],
106
# Mount the shared local storage path at ``/home/.shared`` inside the container.
volumes=[f"{str(_shared_local_mount_path())}:/home/.shared"],
107
command=f"python -m lightning run work {self.file} --work_name={self.work.name} --queue_id {self.queue_id}",
108
# Same in-container path as the volume mount target above.
environment={"SHARED_MOUNT_DIRECTORY": "/home/.shared"},
112
# 2. Check the log and exit when ``Starting WorkRunner`` is found.
113
for line in self.container.logs(stream=True):
114
# NOTE(review): if the stream yields bytes, ``str(...)`` produces the
# ``b'...'`` repr rather than decoded text; the substring checks below still
# match in that case — confirm.
line = str(line.strip())
116
if "command not found" in line:
117
raise RuntimeError("The container wasn't properly executed.")
118
if "Starting WorkRunner" in line:
121
def get_container_logs(self) -> str:
    """Return the logs the container has produced until now.

    The log payload is fetched from ``self.container`` as bytes and decoded as
    UTF-8. Mapping every byte through ``chr`` would amount to a Latin-1
    decode and garble any multi-byte character in the output; bytes that are
    not valid UTF-8 are replaced with U+FFFD instead of raising.

    Returns:
        The container output as text.
    """
    raw_logs = self.container.logs(until=datetime.now())
    return raw_logs.decode("utf-8", errors="replace")
125
def kill(self) -> None:
    """Kill the container stored on ``self.container``."""
    running_container = self.container
    running_container.kill()