pytorch-lightning

Форк
0
82 строки · 3.0 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import logging
15
from typing import Dict
16

17
from typing_extensions import override
18

19
from lightning.fabric.accelerators.xla import _XLA_AVAILABLE
20
from lightning.pytorch.profilers.profiler import Profiler
21

22
# Module-level logger for this file; XLAProfiler below does not currently log through it.
log = logging.getLogger(name=__name__)
23

24

25
class XLAProfiler(Profiler):
    """Record selected Lightning hooks with the ``torch_xla`` debug profiler.

    The hook name is the part of an action name after its last ``.``.
    - Hooks listed in ``RECORD_FUNCTIONS`` are traced.
    - Hooks also listed in ``STEP_FUNCTIONS`` become ``xp.StepTrace`` entries with a per-action
      step counter; the remaining hooks become plain ``xp.Trace`` entries.
    - The profiler server is started on ``port`` the first time a traced hook starts.

    This is intended for use with Cloud TPU performance tools.

    """

    # Hooks traced as numbered steps.
    STEP_FUNCTIONS = {"validation_step", "test_step", "predict_step"}
    # Every hook that gets traced: all step hooks plus training_step and backward.
    RECORD_FUNCTIONS = STEP_FUNCTIONS | {"training_step", "backward"}

    def __init__(self, port: int = 9012) -> None:
        """Set up the profiler; the XLA profiler server itself is started lazily in :meth:`start`.

        Args:
            port: the port the profiler server is started on. An exception is raised when the
                server is started if the provided port is invalid or busy.

        Raises:
            ModuleNotFoundError: if XLA is not available (``_XLA_AVAILABLE`` is falsy).

        """
        if not _XLA_AVAILABLE:
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(dirpath=None, filename=None)
        self.port = port
        # Currently open trace context managers, keyed by full action name.
        self._recording_map: Dict[str, object] = {}
        # Per-action step counters for StepTrace (attribute name kept unchanged for compatibility).
        self._step_recoding_map: Dict[str, int] = {}
        # Flips to True once the profiler server has been started.
        self._start_trace: bool = False

    @override
    def start(self, action_name: str) -> None:
        """Open a trace for ``action_name`` when its hook name is in ``RECORD_FUNCTIONS``.

        Action names have the form ``[TYPE]{class name}.{hook name}``, for example
        ``[LightningModule]BoringModel.training_step``. Only the trailing hook name is used to
        decide whether and how to record.

        """
        import torch_xla.debug.profiler as xp

        hook = action_name.rpartition(".")[2]
        if hook not in self.RECORD_FUNCTIONS:
            return

        if not self._start_trace:
            # Bring up the profiler server lazily, on the first traced hook only.
            self.server = xp.start_server(self.port)
            self._start_trace = True

        recording = (
            xp.StepTrace(action_name, step_num=self._get_step_num(action_name))
            if hook in self.STEP_FUNCTIONS
            else xp.Trace(action_name)
        )
        recording.__enter__()
        # A repeated start() for the same name before stop() replaces the open entry without exiting it.
        self._recording_map[action_name] = recording

    @override
    def stop(self, action_name: str) -> None:
        """Close and forget the open trace for ``action_name``; names with no open trace are ignored."""
        recording = self._recording_map.get(action_name)
        if recording is None:
            return
        recording.__exit__(None, None, None)
        del self._recording_map[action_name]

    def _get_step_num(self, action_name: str) -> int:
        """Advance and return the step counter for ``action_name``; the first call returns 1."""
        step = self._step_recoding_map.get(action_name, 0) + 1
        self._step_recoding_map[action_name] = step
        return step
83

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.