pytorch

Форк
0
/
c10d_logger.py 
98 строк · 3.1 Кб
1
#!/usr/bin/env python3
2

3
# Copyright (c) Facebook, Inc. and its affiliates.
4
# All rights reserved.
5
#
6
# This source code is licensed under the BSD-style license found in the
7
# LICENSE file in the root directory of this source tree.
8

9
import functools
10
import logging
11
import time
12
from typing import Any, Callable, Dict, List, Tuple, TypeVar
13
from typing_extensions import ParamSpec
14

15
import torch
16
import torch.distributed as dist
17

18
from torch.distributed.logging_handlers import _log_handlers
19

20
# No names are exported: every helper in this module is underscore-prefixed,
# and an empty ``__all__`` makes ``from ... import *`` bring in nothing.
__all__: List[str] = []
21

22

23
def _get_or_create_logger() -> logging.Logger:
    """Build the c10d logger used by this module.

    The logger is named ``c10d-<HandlerClassName>`` after the handler that
    ``_get_logging_handler`` returns for the default destination. That handler
    gets a timestamp/file/line/level/process/thread formatter and is attached
    to the logger. The logger is set to DEBUG and does not propagate records
    to ancestor loggers.

    Returns:
        The configured ``logging.Logger``.
    """
    handler, handler_name = _get_logging_handler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
        )
    )

    c10d_logger = logging.getLogger(f"c10d-{handler_name}")
    c10d_logger.setLevel(logging.DEBUG)
    # Records are emitted only through the handler attached below, not through
    # whatever handlers ancestor loggers may carry.
    c10d_logger.propagate = False
    c10d_logger.addHandler(handler)
    return c10d_logger
34

35

36
def _get_logging_handler(destination: str = "default") -> Tuple[logging.Handler, str]:
37
    log_handler = _log_handlers[destination]
38
    log_handler_name = type(log_handler).__name__
39
    return (log_handler, log_handler_name)
40

41

42
# Module-wide c10d logger, created once at import time and used by the
# logging decorators in this module. A ``global`` declaration is unnecessary
# here: assignments at module scope already bind module globals.
_c10d_logger = _get_or_create_logger()
44

45

46
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
47
    if dist.is_initialized():
48
        msg_dict = {
49
            "func_name": f"{func_name}",
50
            "args": f"{args}, {kwargs}",
51
            "pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}",  # type: ignore[arg-type]
52
            "backend": f"{dist.get_backend(kwargs.get('group'))}",
53
            "world_size": f"{dist.get_world_size()}",
54
            "group_size": f"{dist.get_world_size(kwargs.get('group'))}",
55
            "global_rank": f"{dist.get_rank()}",
56
            "local_rank": f"{dist.get_rank(kwargs.get('group'))}",
57
        }
58
        if msg_dict["backend"] == "nccl":
59
            nccl_version = torch.cuda.nccl.version()
60
            msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version)
61
    else:
62
        msg_dict = {
63
            "func_name": f"{func_name}",
64
            "args": f"{args}, {kwargs}",
65
        }
66
    return msg_dict
67

68
_T = TypeVar('_T')
69
_P = ParamSpec('_P')
70

71
def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorate *func* so an exception it raises is logged, then re-raised.

    On success the wrapper returns *func*'s result untouched. If *func* raises
    an ``Exception``, the call metadata from ``_get_msg_dict`` plus an
    ``"error"`` entry is logged at DEBUG level, and the same exception is
    re-raised.

    Gathering that metadata queries process-group state and may itself raise.
    It is therefore best-effort: on failure a reduced record (function name
    plus the metadata error) is logged instead. The caller always sees
    *func*'s own exception, never the logging failure.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        try:
            return func(*args, **kwargs)
        except Exception as error:
            try:
                msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
            except Exception as log_error:
                msg_dict = {
                    "func_name": f"{func.__name__}",
                    "log_error": f"{log_error}",
                }
            msg_dict["error"] = f"{error}"
            _c10d_logger.debug(msg_dict)
            # Bare ``raise`` re-raises ``error``: the inner handler above has
            # exited, so ``error`` is the exception being handled again.
            raise

    return wrapper
83

84

85
def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorate *func* so each successful call logs how long it took.

    After *func* returns, the call metadata from ``_get_msg_dict`` plus a
    ``"time_spent"`` entry (``"<n>ns"``) is logged at DEBUG level, and *func*'s
    return value is passed through. If *func* raises, the exception propagates
    and no timing record is written.

    Gathering the metadata is best-effort. If it raises, a reduced record
    (function name plus the metadata error) is logged, and the already
    successful call still returns its result.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        # perf_counter_ns is the stdlib clock meant for measuring durations;
        # unlike time_ns it does not follow wall-clock adjustments mid-call.
        start_ns = time.perf_counter_ns()
        func_return = func(*args, **kwargs)
        time_spent = time.perf_counter_ns() - start_ns

        try:
            msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
        except Exception as log_error:
            # The wrapped call already succeeded; a failure while collecting
            # log metadata must not turn it into an error for the caller.
            msg_dict = {
                "func_name": f"{func.__name__}",
                "log_error": f"{log_error}",
            }
        msg_dict["time_spent"] = f"{time_spent}ns"
        _c10d_logger.debug(msg_dict)

        return func_return

    return wrapper
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.