# pytorch-lightning (source viewer header: 99 lines · 3.7 KB)
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Profiler to check if there are any bottlenecks in your code."""
15
import cProfile
import io
import logging
import pstats
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from typing_extensions import override

from lightning.pytorch.profilers.profiler import Profiler
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(name=__name__)


class AdvancedProfiler(Profiler):
    """This profiler uses Python's ``cProfile`` to record more detailed information about time spent in each function
    call recorded during a given action.

    The output is quite verbose and you should only use this if you want very detailed reports.

    Actions may be nested. Only the innermost started action is recorded at any moment: starting an action pauses the
    one currently being recorded, and stopping it resumes that enclosing action. Since Python 3.12 ``cProfile`` is
    built on ``sys.monitoring`` and enabling a second ``cProfile.Profile`` while another is active raises
    ``ValueError``, so enabled profilers must never overlap.

    """

    def __init__(
        self,
        dirpath: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
        line_count_restriction: float = 1.0,
    ) -> None:
        """
        Args:
            dirpath: Directory path for the ``filename``. If ``dirpath`` is ``None`` but ``filename`` is present, the
                ``trainer.log_dir`` (from :class:`~lightning.pytorch.loggers.tensorboard.TensorBoardLogger`)
                will be used.

            filename: If present, filename where the profiler results will be saved instead of printing to stdout.
                The ``.txt`` extension will be used automatically.

            line_count_restriction: this can be used to limit the number of functions
                reported for each action. Either an integer (to select a count of lines),
                or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines).

        """
        super().__init__(dirpath=dirpath, filename=filename)
        # One profiler per action name, reused by every start/stop cycle of that action.
        self.profiled_actions: Dict[str, cProfile.Profile] = {}
        self.line_count_restriction = line_count_restriction
        # Names of the currently started actions, innermost last. Only the last one's profiler is enabled.
        self._active_actions: List[str] = []

    @override
    def start(self, action_name: str) -> None:
        """Start (or resume) recording ``action_name``, pausing the action currently being recorded, if any."""
        profiler = self.profiled_actions.get(action_name)
        if profiler is None:
            profiler = self.profiled_actions[action_name] = cProfile.Profile()
        active = self._active_actions
        if active:
            # Pause whatever is running (possibly this very action): on Python 3.12+ enabling a profiler while
            # another one, or the same one, is already enabled raises ``ValueError``.
            self.profiled_actions[active[-1]].disable()
        if action_name in active:
            active.remove(action_name)
        profiler.enable()
        active.append(action_name)

    @override
    def stop(self, action_name: str) -> None:
        """Stop recording ``action_name`` and resume the enclosing action it interrupted, if any.

        Raises:
            ValueError:
                If you attempt to stop recording an action which was never started.

        """
        profiler = self.profiled_actions.get(action_name)
        if profiler is None:
            raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
        active = self._active_actions
        if not active or active[-1] != action_name:
            # This action is not the one being recorded, so its profiler is already disabled (paused by a later
            # ``start`` or stopped earlier). Skip ``disable()``: before Python 3.12 it clears the thread's profile hook
            # unconditionally and would silence the profiler that is actually running.
            if action_name in active:
                active.remove(action_name)
            return
        profiler.disable()
        active.pop()
        if active:
            # Resume the enclosing action that ``start`` paused.
            self.profiled_actions[active[-1]].enable()

    @override
    def summary(self) -> str:
        """Build one ``pstats`` report per action, sorted by cumulative time.

        Each report is truncated according to ``line_count_restriction``. Building ``pstats.Stats`` from a profile
        calls its ``create_stats()``, which disables that profile.

        """
        recorded_stats = {}
        for action_name, pr in self.profiled_actions.items():
            s = io.StringIO()
            ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats("cumulative")
            ps.print_stats(self.line_count_restriction)
            recorded_stats[action_name] = s.getvalue()
        return self._stats_to_str(recorded_stats)

    @override
    def teardown(self, stage: Optional[str]) -> None:
        """Run the base-class teardown, then discard all recorded profiles and the active-action stack."""
        super().teardown(stage=stage)
        if self._active_actions:
            # An action left unstopped (e.g. interrupted by an exception) still has an enabled profiler. Disable it
            # so that no profiler stays enabled once its results are discarded.
            self.profiled_actions[self._active_actions[-1]].disable()
        self.profiled_actions = {}
        self._active_actions = []

    def __reduce__(self) -> Tuple:
        # avoids `TypeError: cannot pickle 'cProfile.Profile' object`: the copy is rebuilt by calling the class with no
        # arguments and then receives only this configuration dict as its state. Recorded profiles and the
        # active-action stack are therefore not carried over.
        return (
            self.__class__,
            (),
            {"dirpath": self.dirpath, "filename": self.filename, "line_count_restriction": self.line_count_restriction},
        )