pytorch-lightning
82 строки · 3.0 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging15from typing import Dict16
17from typing_extensions import override18
19from lightning.fabric.accelerators.xla import _XLA_AVAILABLE20from lightning.pytorch.profilers.profiler import Profiler21
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
24
class XLAProfiler(Profiler):
    """Profiler that wraps selected hooks in ``torch_xla.debug.profiler`` traces."""

    # Hooks that are recorded as numbered step traces.
    STEP_FUNCTIONS = {"validation_step", "test_step", "predict_step"}
    # Every hook that gets recorded: the step hooks plus the training-side hooks.
    RECORD_FUNCTIONS = STEP_FUNCTIONS | {"training_step", "backward"}

    def __init__(self, port: int = 9012) -> None:
        """XLA Profiler will help you debug and optimize training workload performance for your models using Cloud TPU
        performance tools.

        Args:
            port: the port to start the profiler server on. An exception is
                raised if the provided port is invalid or busy.

        Raises:
            ModuleNotFoundError: if ``_XLA_AVAILABLE`` evaluates to false; the
                message is its string form.

        """
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(dirpath=None, filename=None)
        # The profiler server is started lazily, on the first recorded action.
        self._start_trace: bool = False
        self.port = port
        # Open trace objects, keyed by action name.
        self._recording_map: Dict = {}
        # How many times each step-type action has been started so far.
        self._step_recoding_map: Dict = {}

    @override
    def start(self, action_name: str) -> None:
        """Open a trace for ``action_name`` if its hook is listed in ``RECORD_FUNCTIONS``.

        The first recorded action also starts the profiler server on ``self.port``.
        Actions whose hook is not recorded are ignored.

        Args:
            action_name: name of the action, formatted as
                ``'[TYPE]{class name}.{hook name}'``, for example
                ``[LightningModule]BoringModel.training_step``.

        """
        import torch_xla.debug.profiler as xp

        # Only the trailing hook name decides whether and how the action is recorded.
        hook_name = action_name.split(".")[-1]
        if hook_name not in self.RECORD_FUNCTIONS:
            return

        if not self._start_trace:
            self.server = xp.start_server(self.port)
            self._start_trace = True

        if hook_name in self.STEP_FUNCTIONS:
            step_num = self._get_step_num(action_name)
            trace = xp.StepTrace(action_name, step_num=step_num)
        else:
            trace = xp.Trace(action_name)

        # The trace is a context manager that stays open until ``stop`` is called,
        # so it is entered by hand rather than with a ``with`` block.
        trace.__enter__()
        self._recording_map[action_name] = trace

    @override
    def stop(self, action_name: str) -> None:
        """Close the trace opened for ``action_name``; do nothing if none is open.

        Args:
            action_name: the same name that was passed to ``start``.

        """
        if action_name not in self._recording_map:
            return
        # Exit first and forget the trace afterwards, so that an error raised
        # while exiting leaves the entry in place.
        self._recording_map[action_name].__exit__(None, None, None)
        del self._recording_map[action_name]

    def _get_step_num(self, action_name: str) -> int:
        """Advance and return the 1-based start counter kept for ``action_name``.

        Args:
            action_name: the action whose counter should be advanced.

        Returns:
            ``1`` on the first call for a given name, then one more on each later call.

        """
        step_num = self._step_recoding_map.get(action_name, 0) + 1
        self._step_recoding_map[action_name] = step_num
        return step_num