pytorch-lightning

Форк
0
163 строки · 6.7 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""Profiler to check if there are any bottlenecks in your code."""
15

16
import logging
17
import os
18
import time
19
from collections import defaultdict
20
from pathlib import Path
21
from typing import Dict, List, Optional, Tuple, Union
22

23
import torch
24
from typing_extensions import override
25

26
from lightning.pytorch.profilers.profiler import Profiler
27

28
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# One row of the extended report:
# (action name, mean duration, number of calls, total duration, percentage of the total run time).
_TABLE_ROW_EXTENDED = Tuple[str, float, int, float, float]
_TABLE_DATA_EXTENDED = List[_TABLE_ROW_EXTENDED]
# One row of the basic report: (action name, mean duration, total duration).
_TABLE_ROW = Tuple[str, float, float]
_TABLE_DATA = List[_TABLE_ROW]
34

35

36
class SimpleProfiler(Profiler):
    """This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action
    and the total time spent over the entire training run."""

    def __init__(
        self,
        dirpath: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
        extended: bool = True,
    ) -> None:
        """
        Args:
            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
                ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
                will be used.

            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
                The ``.txt`` extension will be used automatically.

            extended: If ``True``, adds extra columns representing number of calls and percentage of total time spent on
                respective action.

        Raises:
            ValueError:
                If you attempt to start an action which has already started, or
                if you attempt to stop recording an action which was never started.
        """
        super().__init__(dirpath=dirpath, filename=filename)
        # ``time.monotonic()`` start stamps of the actions that are currently running, keyed by action name.
        self.current_actions: Dict[str, float] = {}
        # All completed durations (in seconds) of every action, keyed by action name.
        self.recorded_durations: Dict[str, List[float]] = defaultdict(list)
        self.extended = extended
        # Reference point for the overall duration reported by the extended summary.
        self.start_time = time.monotonic()

    @override
    def start(self, action_name: str) -> None:
        """Remember the current monotonic time as the start of ``action_name``.

        Raises:
            ValueError: If ``action_name`` has been started and not stopped yet.
        """
        if action_name in self.current_actions:
            raise ValueError(f"Attempted to start {action_name} which has already started.")
        self.current_actions[action_name] = time.monotonic()

    @override
    def stop(self, action_name: str) -> None:
        """Record the time elapsed since the matching :meth:`start` call for ``action_name``.

        Raises:
            ValueError: If ``action_name`` is not currently running.
        """
        # Take the end stamp first so the bookkeeping below is not included in the measured duration.
        end_time = time.monotonic()
        if action_name not in self.current_actions:
            raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
        start_time = self.current_actions.pop(action_name)
        self.recorded_durations[action_name].append(end_time - start_time)

    def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, int, float]:
        """Build the rows of the extended report.

        Returns:
            A tuple ``(rows, total_calls, total_duration)``. ``rows`` holds one
            ``(action, mean duration, number of calls, summed duration, percentage)`` tuple per action, sorted by
            percentage in descending order. ``total_calls`` is the number of recorded calls over all actions and
            ``total_duration`` is the time elapsed since this profiler was constructed.
        """
        total_duration = time.monotonic() - self.start_time
        report = []

        for action, durations in self.recorded_durations.items():
            num_calls = len(durations)
            summed = torch.sum(torch.tensor(durations)).item()
            # ``recorded_durations`` is a defaultdict, so a mere lookup can leave an empty entry behind;
            # report such an action with a zero mean instead of dividing by zero.
            mean = summed / num_calls if num_calls else 0.0
            # A coarse monotonic clock may report no elapsed time at all.
            percentage = 100.0 * summed / total_duration if total_duration > 0 else 0.0
            report.append((action, mean, num_calls, summed, percentage))

        report.sort(key=lambda row: row[4], reverse=True)
        total_calls = sum(row[2] for row in report)
        return report, total_calls, total_duration

    def _make_report(self) -> _TABLE_DATA:
        """Build the rows of the basic report.

        Returns:
            One ``(action, mean duration, summed duration)`` tuple per action, sorted by mean duration in
            descending order.
        """
        report = []
        for action, durations in self.recorded_durations.items():
            summed = torch.sum(torch.tensor(durations)).item()
            # Same empty-entry guard as in the extended report.
            mean = summed / len(durations) if durations else 0.0
            report.append((action, mean, summed))

        report.sort(key=lambda row: row[1], reverse=True)
        return report

    @override
    def summary(self) -> str:
        """Render the recorded durations as a text table.

        The first line reads ``Profiler Report``, prefixed with the upper-cased ``self._stage`` when that attribute
        is not ``None``. If at least one action was recorded, a table follows: the extended layout (mean, number of
        calls, total and percentage, plus a leading ``Total`` row) when ``self.extended`` is true, otherwise the
        basic layout (mean and total). If nothing was recorded, only the title line is returned in either mode.

        Returns:
            The report, using ``os.linesep`` as the line separator.
        """
        sep = os.linesep
        parts = []
        if self._stage is not None:
            parts.append(f"{self._stage.upper()} ")
        parts.append(f"Profiler Report{sep}")

        # ``max()`` raises ``ValueError`` on an empty iterable, so the table is skipped entirely (for both layouts)
        # when no action has been recorded.
        if self.recorded_durations:
            # The action column is as wide as the longest recorded action name.
            max_key = max(len(k) for k in self.recorded_durations)

            if self.extended:

                def log_row_extended(action: str, mean: str, num_calls: str, total: str, per: str) -> str:
                    row = f"{sep}|  {action:<{max_key}s}\t|  {mean:<15}\t|"
                    row += f"  {num_calls:<15}\t|  {total:<15}\t|  {per:<15}\t|"
                    return row

                header_string = log_row_extended(
                    "Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %"
                )
                # The horizontal rules are as long as the header once its tabs are expanded.
                output_string_len = len(header_string.expandtabs())
                sep_lines = f"{sep}{'-' * output_string_len}"
                parts.append(sep_lines + header_string + sep_lines)

                report_extended: _TABLE_DATA_EXTENDED
                report_extended, total_calls, total_duration = self._make_report_extended()
                parts.append(log_row_extended("Total", "-", f"{total_calls:}", f"{total_duration:.5}", "100 %"))
                parts.append(sep_lines)
                # ``action_duration`` is named distinctly so it does not shadow the overall ``total_duration``.
                for action, mean_duration, num_calls, action_duration, duration_per in report_extended:
                    parts.append(
                        log_row_extended(
                            action,
                            f"{mean_duration:.5}",
                            f"{num_calls}",
                            f"{action_duration:.5}",
                            f"{duration_per:.5}",
                        )
                    )
                parts.append(sep_lines)
            else:

                def log_row(action: str, mean: str, total: str) -> str:
                    return f"{sep}|  {action:<{max_key}s}\t|  {mean:<15}\t|  {total:<15}\t|"

                header_string = log_row("Action", "Mean duration (s)", "Total time (s)")
                output_string_len = len(header_string.expandtabs())
                sep_lines = f"{sep}{'-' * output_string_len}"
                parts.append(sep_lines + header_string + sep_lines)

                for action, mean_duration, action_duration in self._make_report():
                    parts.append(log_row(action, f"{mean_duration:.5}", f"{action_duration:.5}"))
                parts.append(sep_lines)

        parts.append(sep)
        return "".join(parts)
164

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.