pytorch-lightning

Форк
0
127 строк · 4.6 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
import pickle
17
import shutil
18
import sys
19
from datetime import datetime
20
from typing import Optional
21

22
from lightning.app import _PROJECT_ROOT, LightningWork
23
from lightning.app.storage.path import _shared_local_mount_path
24
from lightning.app.utilities.imports import _is_docker_available, _is_jinja2_available, requires
25

26
if _is_docker_available():
27
    import docker
28
    from docker.models.containers import Container
29

30
if _is_jinja2_available():
31
    import jinja2
32

33

34
class DockerRunner:
    """Build a docker image for a ``LightningWork`` and run that work inside a container.

    On construction the work is pickled into ``_PROJECT_ROOT`` next to a Dockerfile rendered from
    ``dockers/Dockerfile.jinja``, and an image tagged ``work-<lowercased work class qualname>`` is built
    with ``_PROJECT_ROOT`` as the build context. :meth:`run` then starts a container from that image
    which executes ``python -m lightning run work``.
    """

    @requires("docker")
    def __init__(self, file: str, work: LightningWork, queue_id: str, create_base: bool = False):
        """
        Args:
            file: Forwarded as the positional file argument of ``lightning run work`` in the container.
            work: The work to pickle into the image and run.
            queue_id: Forwarded to ``lightning run work`` as ``--queue_id``.
            create_base: When ``True``, first build the ``thomaschaton/base`` image from
                ``dockers/Dockerfile.base.cpu``.

        Raises:
            subprocess.CalledProcessError: If a ``docker build`` invocation exits with a non-zero status.
        """
        self.file = file
        self.work = work
        self.queue_id = queue_id
        # Tag of the built work image; set by ``_create_work_container``.
        self.image: Optional[str] = None
        if create_base:
            self._create_base_container()
        self._create_work_container()

    @staticmethod
    def _build_image(tag: str) -> None:
        """Run ``docker build`` on ``_PROJECT_ROOT`` (using its ``Dockerfile``) and tag the result ``tag``.

        Raises:
            subprocess.CalledProcessError: If ``docker build`` exits with a non-zero status.
        """
        import subprocess

        # Build from ``_PROJECT_ROOT`` so the context and ``./Dockerfile`` are the files this class writes
        # there, regardless of the caller's working directory. An argv list avoids going through a shell.
        subprocess.run(["docker", "build", ".", "--tag", tag], cwd=_PROJECT_ROOT, check=True)

    @staticmethod
    def _remove_if_exists(*paths: str) -> None:
        """Best-effort removal of temporary build files; paths that do not exist are skipped."""
        for path in paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                # The file was never created (an earlier step failed), so there is nothing to clean up.
                pass

    @staticmethod
    def _to_text(data: object) -> str:
        """Decode docker log output (bytes) as UTF-8, replacing undecodable bytes."""
        if isinstance(data, (bytes, bytearray)):
            return data.decode("utf-8", errors="replace")
        return str(data)

    def _create_base_container(self) -> None:
        """Build the ``thomaschaton/base`` image from ``dockers/Dockerfile.base.cpu``."""
        container_base = os.path.join(_PROJECT_ROOT, "dockers/Dockerfile.base.cpu")
        destination_path = os.path.join(_PROJECT_ROOT, "Dockerfile")

        # The base Dockerfile is copied to the root of the project so the build uses the project as context.
        # NOTE(review): this overwrites, then deletes, any ``Dockerfile`` already present at ``_PROJECT_ROOT``.
        shutil.copy(container_base, destination_path)
        try:
            self._build_image("thomaschaton/base")
        finally:
            # Always drop the copied Dockerfile, even when the build fails.
            self._remove_if_exists(destination_path)

    def _create_work_container(self) -> None:
        """Pickle ``self.work``, render the work Dockerfile and build the work image.

        Sets ``self.image`` to ``work-<lowercased work class qualname>``. The temporary ``Dockerfile`` and
        ``work.p`` written to ``_PROJECT_ROOT`` are removed afterwards, whether or not the steps succeed.

        Raises:
            subprocess.CalledProcessError: If ``docker build`` exits with a non-zero status.
        """
        source_path = os.path.join(_PROJECT_ROOT, "dockers/Dockerfile.jinja")
        destination_path = os.path.join(_PROJECT_ROOT, "Dockerfile")
        work_destination_path = os.path.join(_PROJECT_ROOT, "work.p")

        try:
            # The pickled work is written into the build context (presumably copied into the image by the
            # Dockerfile template — confirm against ``dockers/Dockerfile.jinja``).
            with open(work_destination_path, "wb") as f:
                pickle.dump(self.work, f)

            with open(source_path) as f:
                template = jinja2.Template(f.read())

            # The requirements come from the work's local build spec.
            requirements = self.work.local_build_config.requirements
            dockerfile_str = template.render(
                requirements=" ".join(requirements),
                # presumably: Docker Desktop on macOS exposes the host as ``host.docker.internal``.
                redis_host="host.docker.internal" if sys.platform == "darwin" else "127.0.0.1",
            )
            with open(destination_path, "w") as f:
                f.write(dockerfile_str)

            self.image = f"work-{self.work.__class__.__qualname__.lower()}"
            self._build_image(self.image)
        finally:
            self._remove_if_exists(destination_path, work_destination_path)

    def run(self) -> None:
        """Start the work container and block until its output reports ``Starting WorkRunner``.

        Container output is echoed to stdout while waiting.

        Raises:
            RuntimeError: If no work image has been built, or the container output contains
                ``command not found``.
        """
        if not self.image:
            raise RuntimeError("The work image hasn't been built.")

        # 1. Run the work container and launch the work.
        client = docker.DockerClient(base_url="unix://var/run/docker.sock")
        self.container: Container = client.containers.run(
            image=self.image,
            shm_size="10G",
            stderr=True,
            stdout=True,
            stdin_open=True,
            detach=True,
            ports=[url.split(":")[-1] for url in self.work._urls if url],
            volumes=[f"{_shared_local_mount_path()}:/home/.shared"],
            # An argv list keeps arguments containing spaces (e.g. in the file path) intact.
            command=[
                "python",
                "-m",
                "lightning",
                "run",
                "work",
                str(self.file),
                f"--work_name={self.work.name}",
                "--queue_id",
                str(self.queue_id),
            ],
            environment={"SHARED_MOUNT_DIRECTORY": "/home/.shared"},
            network_mode="host",
        )

        # 2. Echo the log and stop waiting once ``Starting WorkRunner`` is found.
        for chunk in self.container.logs(stream=True):
            line = self._to_text(chunk).strip()
            print(line)
            if "command not found" in line:
                raise RuntimeError("The container wasn't properly executed.")
            if "Starting WorkRunner" in line:
                break
        # NOTE(review): if the log stream ends before ``Starting WorkRunner`` appears (e.g. the container
        # exited), this returns normally — confirm whether that should be treated as a failure.

    def get_container_logs(self) -> str:
        """Returns the logs of the container produced until now, decoded as UTF-8.

        Must be called after :meth:`run`, which creates ``self.container``.
        """
        # NOTE(review): ``datetime.now()`` is a naive local time; confirm how the docker SDK converts it
        # for ``until`` (a UTC interpretation would cut off recent logs in timezones behind UTC).
        return self._to_text(self.container.logs(until=datetime.now()))

    def kill(self) -> None:
        """Kill the container started by :meth:`run`."""
        self.container.kill()
128

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.