# pytorch-lightning (fork) — 151 lines · 5.3 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Profiler to check if there are any bottlenecks in your code."""
15

16
import logging
17
import os
18
from abc import ABC, abstractmethod
19
from contextlib import contextmanager
20
from pathlib import Path
21
from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union
22

23
from lightning.fabric.utilities.cloud_io import get_filesystem
24

25
# Module-level logger; `Profiler._rank_zero_info` emits reports through it when no output file is configured.
log = logging.getLogger(__name__)
26

27

28
class Profiler(ABC):
    """If you wish to write a custom profiler, you should inherit from this class.

    Subclasses must implement :meth:`start` and :meth:`stop`, and may override :meth:`summary` to provide the
    report that :meth:`describe` emits.

    Args:
        dirpath: Directory in which the report file is written. When unset, :meth:`setup` falls back to its
            ``log_dir`` argument.
        filename: Base name of the report file. A file is only written when both ``dirpath`` and ``filename``
            are set; otherwise the report is logged via the module logger when the local rank is ``None`` or ``0``.

    """

    def __init__(
        self,
        dirpath: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
    ) -> None:
        self.dirpath = dirpath
        self.filename = filename

        # Report file handle and the callable used to emit the report; created lazily by
        # `_prepare_streams` and released by `teardown`.
        self._output_file: Optional[TextIO] = None
        self._write_stream: Optional[Callable] = None
        # Filled in by `setup`.
        self._local_rank: Optional[int] = None
        self._stage: Optional[str] = None

    @abstractmethod
    def start(self, action_name: str) -> None:
        """Defines how to start recording an action."""

    @abstractmethod
    def stop(self, action_name: str) -> None:
        """Defines how to record the duration once an action is complete."""

    def summary(self) -> str:
        """Return the textual profiling report.

        The base implementation returns an empty string, in which case :meth:`describe` writes nothing.

        """
        return ""

    @contextmanager
    def profile(self, action_name: str) -> Generator:
        """Yields a context manager to encapsulate the scope of a profiled action.

        Example::

            with self.profile('load training data'):
                # load training data code

        The profiler will start once you've entered the context and will automatically
        stop once you exit the code block, even if the block raises.

        Args:
            action_name: Name of the profiled action; it is also the value bound by ``as``.

        """
        # Start outside the ``try``: if ``start`` fails the action was never started, so ``stop`` must not run.
        # Otherwise it could raise a second error hiding the first, or close an outer action with the same name.
        self.start(action_name)
        try:
            yield action_name
        finally:
            self.stop(action_name)

    def _rank_zero_info(self, *args: Any, **kwargs: Any) -> None:
        """Forward to ``log.info`` only when the local rank is unknown (``None``) or ``0``."""
        if self._local_rank in (None, 0):
            log.info(*args, **kwargs)

    def _prepare_filename(
        self, action_name: Optional[str] = None, extension: str = ".txt", split_token: str = "-"
    ) -> str:
        """Build a report file name from whichever of stage, filename, local rank and action name are set.

        Unset parts are skipped, e.g. stage ``"fit"``, filename ``"prof"`` and local rank ``1`` give
        ``"fit-prof-1.txt"``.

        """
        args = []
        if self._stage is not None:
            args.append(self._stage)
        if self.filename:
            args.append(self.filename)
        if self._local_rank is not None:
            args.append(str(self._local_rank))
        if action_name is not None:
            args.append(action_name)
        return split_token.join(args) + extension

    def _prepare_streams(self) -> None:
        """Set ``_write_stream`` unless it is already set.

        When both ``filename`` and ``dirpath`` are set, ``dirpath`` is created if needed and the report file is
        opened in append mode; otherwise output is routed to :meth:`_rank_zero_info`.

        """
        if self._write_stream is not None:
            return
        if self.filename and self.dirpath:
            filepath = os.path.join(self.dirpath, self._prepare_filename())
            fs = get_filesystem(filepath)
            fs.mkdirs(self.dirpath, exist_ok=True)
            file = fs.open(filepath, "a")
            self._output_file = file
            self._write_stream = file.write
        else:
            self._write_stream = self._rank_zero_info

    def describe(self) -> None:
        """Logs a profile report after the conclusion of run."""
        # users might call `describe` directly as the profilers can be used by themselves.
        # to allow this, we open and close the files within this function by calling `_prepare_streams` and `teardown`
        # manually instead of letting the `Trainer` do it through `setup` and `teardown`
        self._prepare_streams()
        try:
            summary = self.summary()
            if summary and self._write_stream is not None:
                self._write_stream(summary)
            if self._output_file is not None:
                self._output_file.flush()
        finally:
            # Release the file even if building or writing the summary fails.
            self.teardown(stage=self._stage)

    def _stats_to_str(self, stats: Dict[str, str]) -> str:
        """Format per-action stats as a report: a title line, then a header and the value for each action."""
        stage = f"{self._stage.upper()} " if self._stage is not None else ""
        output = [stage + "Profiler Report"]
        for action, value in stats.items():
            header = f"Profile stats for: {action}"
            if self._local_rank is not None:
                header += f" rank: {self._local_rank}"
            output.append(header)
            output.append(value)
        return os.linesep.join(output)

    def setup(self, stage: str, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
        """Execute arbitrary pre-profiling set-up steps.

        Records the stage and local rank; ``log_dir`` is only used when no ``dirpath`` was given.

        """
        self._stage = stage
        self._local_rank = local_rank
        self.dirpath = self.dirpath or log_dir

    def teardown(self, stage: Optional[str]) -> None:
        """Execute arbitrary post-profiling tear-down steps.

        Closes the currently open file and stream.

        """
        self._write_stream = None
        # Detach the handle before closing so it is dropped even if `close` raises
        # (and because a TextIOWrapper can't be pickled).
        output_file, self._output_file = self._output_file, None
        if output_file is not None:
            output_file.close()

    def __del__(self) -> None:
        self.teardown(stage=self._stage)

    @property
    def local_rank(self) -> int:
        """The local rank passed to :meth:`setup`, or ``0`` if none was given."""
        return 0 if self._local_rank is None else self._local_rank
152

# Использование cookies
#
# Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.
#
# Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.
#
# Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.