pytorch-lightning
163 строки · 6.7 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Profiler to check if there are any bottlenecks in your code."""
15
16import logging
17import os
18import time
19from collections import defaultdict
20from pathlib import Path
21from typing import Dict, List, Optional, Tuple, Union
22
23import torch
24from typing_extensions import override
25
26from lightning.pytorch.profilers.profiler import Profiler
27
# Shapes of the rows produced by the report builders of ``SimpleProfiler``.
# Plain row: (action name, mean duration, total duration).
_TABLE_ROW = Tuple[str, float, float]
_TABLE_DATA = List[_TABLE_ROW]
# Extended row: (action name, mean duration, number of calls, total duration, percentage of the run).
_TABLE_ROW_EXTENDED = Tuple[str, float, int, float, float]
_TABLE_DATA_EXTENDED = List[_TABLE_ROW_EXTENDED]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
34
35
class SimpleProfiler(Profiler):
    """This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action
    and the total time spent over the entire training run.

    Durations are measured with :func:`time.monotonic`. The total run time used for the percentage column of the
    extended report is counted from the moment the profiler instance is created.

    """

    def __init__(
        self,
        dirpath: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
        extended: bool = True,
    ) -> None:
        """
        Args:
            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
                ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
                will be used.

            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
                The ``.txt`` extension will be used automatically.

            extended: If ``True``, adds extra columns representing number of calls and percentage of total time spent on
                respective action.

        Raises:
            ValueError:
                Not raised by the constructor itself, but by :meth:`start` if you attempt to start an action which
                has already started, and by :meth:`stop` if you attempt to stop recording an action which was never
                started.
        """
        super().__init__(dirpath=dirpath, filename=filename)
        # Actions currently being timed, mapped to their monotonic start timestamp.
        self.current_actions: Dict[str, float] = {}
        # Every finished measurement (in seconds), grouped by action name.
        self.recorded_durations: Dict[str, List[float]] = defaultdict(list)
        self.extended = extended
        # Reference point for the overall run duration reported by the extended table.
        self.start_time = time.monotonic()

    @override
    def start(self, action_name: str) -> None:
        """Begin timing ``action_name``.

        Raises:
            ValueError: If ``action_name`` is already being timed.
        """
        if action_name in self.current_actions:
            raise ValueError(f"Attempted to start {action_name} which has already started.")
        self.current_actions[action_name] = time.monotonic()

    @override
    def stop(self, action_name: str) -> None:
        """Finish timing ``action_name`` and store the elapsed time.

        Raises:
            ValueError: If ``action_name`` was never started.
        """
        # Read the clock first so the bookkeeping below is not counted as part of the action.
        end_time = time.monotonic()
        if action_name not in self.current_actions:
            raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
        start_time = self.current_actions.pop(action_name)
        self.recorded_durations[action_name].append(end_time - start_time)

    def _make_report_extended(self) -> Tuple[_TABLE_DATA_EXTENDED, int, float]:
        """Build the rows of the extended report.

        Returns:
            A tuple ``(rows, total_calls, total_duration)`` where each row is
            ``(action, mean duration, number of calls, total duration, percentage of the run)``, rows are sorted by
            descending percentage, ``total_calls`` is the number of calls summed over all rows and
            ``total_duration`` is the time elapsed since the profiler was created.
        """
        total_duration = time.monotonic() - self.start_time
        report: _TABLE_DATA_EXTENDED = []

        for action, durations in self.recorded_durations.items():
            num_calls = len(durations)
            if num_calls == 0:
                # Reading a missing key of the ``defaultdict`` inserts an empty list; such entries have no
                # measurements and would otherwise cause a division by zero when computing the mean.
                continue
            action_total = torch.sum(torch.tensor(durations)).item()
            # A coarse monotonic clock can report no elapsed time at all; avoid dividing by zero in that case.
            percentage = 100.0 * action_total / total_duration if total_duration > 0 else 0.0
            report.append((action, action_total / num_calls, num_calls, action_total, percentage))

        report.sort(key=lambda row: row[4], reverse=True)
        total_calls = sum(row[2] for row in report)
        return report, total_calls, total_duration

    def _make_report(self) -> _TABLE_DATA:
        """Build the rows of the plain report.

        Returns:
            Rows of ``(action, mean duration, total duration)`` sorted by descending mean duration.
        """
        report: _TABLE_DATA = []
        for action, durations in self.recorded_durations.items():
            if not durations:
                # Entry created by a lookup on the ``defaultdict`` without any measurement; nothing to report.
                continue
            action_total = torch.sum(torch.tensor(durations)).item()
            report.append((action, action_total / len(durations), action_total))

        report.sort(key=lambda row: row[1], reverse=True)
        return report

    @override
    def summary(self) -> str:
        """Render the recorded durations as a text table.

        The result starts with a ``Profiler Report`` header line, prefixed with the upper-cased stage name when
        ``self._stage`` is set. When at least one action name is known, the header is followed by a table whose
        columns depend on ``self.extended``. When nothing has been recorded, only the header is returned,
        regardless of ``self.extended``.
        """
        sep = os.linesep
        output_string = ""
        if self._stage is not None:
            output_string += f"{self._stage.upper()} "
        output_string += f"Profiler Report{sep}"

        # Without any recorded action there is no table to draw; this guard also keeps ``max`` below from
        # being called on an empty sequence in both the extended and the plain layout.
        if self.recorded_durations:
            # The first column is as wide as the longest action name.
            max_key = max(len(k) for k in self.recorded_durations)

            if self.extended:

                def log_row_extended(action: str, mean: str, num_calls: str, total: str, per: str) -> str:
                    row = f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t|"
                    row += f" {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
                    return row

                header_string = log_row_extended(
                    "Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %"
                )
                # The horizontal rules are as long as the header row with its tabs expanded.
                output_string_len = len(header_string.expandtabs())
                sep_lines = f"{sep}{'-' * output_string_len}"
                output_string += sep_lines + header_string + sep_lines
                report_extended: _TABLE_DATA_EXTENDED
                report_extended, total_calls, total_duration = self._make_report_extended()
                output_string += log_row_extended("Total", "-", f"{total_calls}", f"{total_duration:.5}", "100 %")
                output_string += sep_lines
                for action, mean_duration, num_calls, action_total, duration_per in report_extended:
                    output_string += log_row_extended(
                        action,
                        f"{mean_duration:.5}",
                        f"{num_calls}",
                        f"{action_total:.5}",
                        f"{duration_per:.5}",
                    )
                output_string += sep_lines
            else:

                def log_row(action: str, mean: str, total: str) -> str:
                    return f"{sep}| {action:<{max_key}s}\t| {mean:<15}\t| {total:<15}\t|"

                header_string = log_row("Action", "Mean duration (s)", "Total time (s)")
                # The horizontal rules are as long as the header row with its tabs expanded.
                output_string_len = len(header_string.expandtabs())
                sep_lines = f"{sep}{'-' * output_string_len}"
                output_string += sep_lines + header_string + sep_lines

                for action, mean_duration, action_total in self._make_report():
                    output_string += log_row(action, f"{mean_duration:.5}", f"{action_total:.5}")
                output_string += sep_lines

        output_string += sep
        return output_string
164