pytorch-lightning
151 lines · 5.3 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Profiler to check if there are any bottlenecks in your code."""
15
16import logging
17import os
18from abc import ABC, abstractmethod
19from contextlib import contextmanager
20from pathlib import Path
21from typing import Any, Callable, Dict, Generator, Optional, TextIO, Union
22
23from lightning.fabric.utilities.cloud_io import get_filesystem
24
# Module-level logger: profiler reports that are not written to a file are emitted through it.
log: logging.Logger = logging.getLogger(name=__name__)
26
27
class Profiler(ABC):
    """If you wish to write a custom profiler, you should inherit from this class.

    Subclasses must implement :meth:`start` and :meth:`stop`. They may override :meth:`summary` to
    produce the report text that :meth:`describe` writes out.

    Args:
        dirpath: Directory in which the report file is written. If not given, :meth:`setup` falls back to
            the ``log_dir`` it receives.
        filename: Base name of the report file. If either ``filename`` or ``dirpath`` is missing, the
            report is sent to the module logger instead of a file. That happens only on local rank 0, or
            when the rank is unknown.

    """

    def __init__(
        self,
        dirpath: Optional[Union[str, Path]] = None,
        filename: Optional[str] = None,
    ) -> None:
        self.dirpath = dirpath
        self.filename = filename

        # Output plumbing: created lazily by `_prepare_streams`, released by `teardown`.
        self._output_file: Optional[TextIO] = None
        self._write_stream: Optional[Callable] = None
        # Filled in by `setup`; `None` means "not set up yet".
        self._local_rank: Optional[int] = None
        self._stage: Optional[str] = None

    @abstractmethod
    def start(self, action_name: str) -> None:
        """Defines how to start recording an action."""

    @abstractmethod
    def stop(self, action_name: str) -> None:
        """Defines how to record the duration once an action is complete."""

    def summary(self) -> str:
        """Return the report text written by :meth:`describe` (the base implementation reports nothing)."""
        return ""

    @contextmanager
    def profile(self, action_name: str) -> Generator:
        """Yields a context manager to encapsulate the scope of a profiled action.

        Example::

            with self.profile('load training data'):
                # load training data code

        The profiler will start once you've entered the context and will automatically
        stop once you exit the code block, even if the block raises. If :meth:`start` itself
        raises, :meth:`stop` is not called and the original exception propagates unchanged.

        """
        # Start outside the `try`: if starting fails there is nothing to stop. Calling `stop` for an
        # action that never started would raise a second, misleading error on top of the real one.
        self.start(action_name)
        try:
            yield action_name
        finally:
            self.stop(action_name)

    def _rank_zero_info(self, *args: Any, **kwargs: Any) -> None:
        # Only log on local rank 0 (or when the rank is unknown) so reports are not duplicated per process.
        if self._local_rank in (None, 0):
            log.info(*args, **kwargs)

    def _prepare_filename(
        self, action_name: Optional[str] = None, extension: str = ".txt", split_token: str = "-"
    ) -> str:
        """Build a report filename of the form ``[stage]-[filename]-[local_rank]-[action_name]<extension>``.

        Components that are not set are skipped. The remaining ones are joined with ``split_token``.

        """
        args = []
        if self._stage is not None:
            args.append(self._stage)
        if self.filename:
            args.append(self.filename)
        if self._local_rank is not None:
            args.append(str(self._local_rank))
        if action_name is not None:
            args.append(action_name)
        return split_token.join(args) + extension

    def _prepare_streams(self) -> None:
        """Set up ``_write_stream`` unless it is already set.

        If both ``filename`` and ``dirpath`` are set, the report is appended to a file under ``dirpath``,
        and the directory is created if needed. Otherwise the report is routed through
        :meth:`_rank_zero_info`.

        """
        if self._write_stream is not None:
            return
        if self.filename and self.dirpath:
            filepath = os.path.join(self.dirpath, self._prepare_filename())
            fs = get_filesystem(filepath)
            fs.mkdirs(self.dirpath, exist_ok=True)
            file = fs.open(filepath, "a")
            self._output_file = file
            self._write_stream = file.write
        else:
            self._write_stream = self._rank_zero_info

    def describe(self) -> None:
        """Logs a profile report after the conclusion of run."""
        # Users might call `describe` directly, because profilers can be used on their own.
        # To allow this, this method opens and closes the output itself by calling `_prepare_streams`
        # and `teardown`, instead of relying on the `Trainer` to do it through `setup` and `teardown`.
        self._prepare_streams()
        try:
            summary = self.summary()
            if summary and self._write_stream is not None:
                self._write_stream(summary)
            if self._output_file is not None:
                self._output_file.flush()
        finally:
            # Release the stream/file even if `summary()` or the write fails, so the handle is not leaked.
            self.teardown(stage=self._stage)

    def _stats_to_str(self, stats: Dict[str, str]) -> str:
        """Format per-action stats as report text.

        The report starts with a title line. Each action then gets a ``Profile stats for: ...`` line
        (with the local rank appended if known) followed by its value. All lines are joined with
        ``os.linesep``.

        """
        stage = f"{self._stage.upper()} " if self._stage is not None else ""
        output = [stage + "Profiler Report"]
        for action, value in stats.items():
            header = f"Profile stats for: {action}"
            if self._local_rank is not None:
                header += f" rank: {self._local_rank}"
            output.append(header)
            output.append(value)
        return os.linesep.join(output)

    def setup(self, stage: str, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
        """Execute arbitrary pre-profiling set-up steps.

        Records ``stage`` and ``local_rank``. ``log_dir`` becomes the output directory only if no
        (non-empty) ``dirpath`` was set before.

        """
        self._stage = stage
        self._local_rank = local_rank
        self.dirpath = self.dirpath or log_dir

    def teardown(self, stage: Optional[str]) -> None:
        """Execute arbitrary post-profiling tear-down steps.

        Closes the currently open file and stream. ``stage`` is not used by this base implementation.

        """
        self._write_stream = None
        if self._output_file is not None:
            self._output_file.close()
            self._output_file = None  # can't pickle TextIOWrapper

    def __del__(self) -> None:
        # `__init__` may not have completed (e.g. a subclass raised before calling `super().__init__()`).
        # In that case there is nothing to tear down, and reading `self._stage` would raise during GC.
        if hasattr(self, "_stage"):
            self.teardown(stage=self._stage)

    @property
    def local_rank(self) -> int:
        """The local rank given to :meth:`setup`, or ``0`` if none was set."""
        return 0 if self._local_rank is None else self._local_rank
152