pytorch
98 строк · 3.1 Кб
1#!/usr/bin/env python3
2
3# Copyright (c) Facebook, Inc. and its affiliates.
4# All rights reserved.
5#
6# This source code is licensed under the BSD-style license found in the
7# LICENSE file in the root directory of this source tree.
8
9import functools
10import logging
11import time
12from typing import Any, Callable, Dict, List, Tuple, TypeVar
13from typing_extensions import ParamSpec
14
15import torch
16import torch.distributed as dist
17
18from torch.distributed.logging_handlers import _log_handlers
19
# Every helper defined in this module is underscore-prefixed (private); the
# empty ``__all__`` makes ``from <module> import *`` export nothing.
__all__: List[str] = []
21
22
def _get_or_create_logger(destination: str = "default") -> logging.Logger:
    """Return the c10d logger wired to the handler registered for ``destination``.

    The logger is named ``c10d-<HandlerClassName>``. It is set to DEBUG, stops
    propagating to ancestor loggers, and has the selected handler (with the
    c10d formatter applied) attached. ``logging.getLogger`` returns the same
    logger object for the same name, and ``Logger.addHandler`` ignores a
    handler that is already attached, so repeated calls are safe.

    Args:
        destination: Key of the registered log handler to use. Defaults to
            ``"default"``, which is the only destination earlier versions
            of this function supported.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        Whatever ``_get_logging_handler`` raises for an unknown
        ``destination`` (presumably ``KeyError`` — the lookup is a subscript).
    """
    logging_handler, log_handler_name = _get_logging_handler(destination)
    logger = logging.getLogger(f"c10d-{log_handler_name}")
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    logging_handler.setFormatter(formatter)
    # Emit c10d records only through the c10d handler, not also through the
    # handlers of ancestor (e.g. root) loggers.
    logger.propagate = False
    logger.addHandler(logging_handler)
    return logger
34
35
def _get_logging_handler(destination: str = "default") -> Tuple[logging.Handler, str]:
    """Look up the registered log handler for ``destination``.

    Args:
        destination: Key into ``_log_handlers``.

    Returns:
        A ``(handler, handler_class_name)`` pair; the class name is what
        ``_get_or_create_logger`` uses to build the logger's name.
    """
    selected = _log_handlers[destination]
    return selected, type(selected).__name__
40
41
# Module-wide logger used by the decorators in this file. This is an ordinary
# module-level assignment, so no ``global`` declaration is needed here.
_c10d_logger: logging.Logger = _get_or_create_logger()
44
45
def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
    """Build the structured record logged for a call to a c10d function.

    Args:
        func_name: Name of the function that was called.
        *args, **kwargs: Arguments the function was called with; they are
            rendered into the ``"args"`` entry as ``"<args>, <kwargs>"``.

    Returns:
        A dict of strings that always contains ``"func_name"`` and ``"args"``.
        When the default process group is initialized it additionally holds
        ``"pg_name"``, ``"backend"``, ``"world_size"``, ``"group_size"``,
        ``"global_rank"`` and ``"local_rank"``, plus ``"nccl_version"`` when
        the backend is NCCL.

    Note:
        The process group is read only from the ``group`` keyword (or its
        ``process_group`` alias). A group passed positionally is not seen, and
        ``None`` is passed to the ``torch.distributed`` queries instead, which
        they treat as the default group.
    """
    # These two entries are identical whether or not distributed is initialized.
    msg_dict: Dict[str, Any] = {
        "func_name": f"{func_name}",
        "args": f"{args}, {kwargs}",
    }
    if not dist.is_initialized():
        return msg_dict

    # Collectives take the group as ``group``; some APIs call it ``process_group``.
    group = kwargs.get("group")
    if group is None:
        group = kwargs.get("process_group")
    # An explicit ``pg`` keyword still wins for the name lookup; otherwise name
    # the same group the other fields describe (previously this always looked
    # up ``kwargs.get('pg')``, which the c10d APIs never pass).
    pg = kwargs.get("pg", group)

    msg_dict.update(
        {
            "pg_name": f"{dist._get_process_group_name(pg)}",  # type: ignore[arg-type]
            "backend": f"{dist.get_backend(group)}",
            "world_size": f"{dist.get_world_size()}",
            "group_size": f"{dist.get_world_size(group)}",
            "global_rank": f"{dist.get_rank()}",
            "local_rank": f"{dist.get_rank(group)}",
        }
    )
    if msg_dict["backend"] == "nccl":
        nccl_version = torch.cuda.nccl.version()
        msg_dict["nccl_version"] = ".".join(str(v) for v in nccl_version)
    return msg_dict
67
# Type variables for the decorators below: ``_T`` is the wrapped function's
# return type and ``_P`` its parameter specification, so a decorated function
# keeps its original signature for type checkers.
_T = TypeVar('_T')
_P = ParamSpec('_P')
70
def _exception_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorate ``func`` so that any exception it raises is logged, then re-raised.

    When ``func`` raises an ``Exception``, a message dict describing the call
    (built by ``_get_msg_dict``) plus an ``"error"`` entry holding
    ``str(error)`` is emitted at DEBUG level on ``_c10d_logger``. The original
    exception is then re-raised unchanged. Successful calls are not logged.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        try:
            return func(*args, **kwargs)
        except Exception as error:
            try:
                msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
            except Exception:
                # Gathering distributed context can fail for the same reason
                # ``func`` did (e.g. an invalid ``group``). Fall back to the
                # minimal record so the caller's original error is never
                # replaced by a logging error.
                msg_dict = {
                    "func_name": f"{func.__name__}",
                    "args": f"{args}, {kwargs}",
                }
            msg_dict["error"] = f"{error}"
            _c10d_logger.debug(msg_dict)
            # The inner handler has exited, so this re-raises ``error``.
            raise

    return wrapper
83
84
def _time_logger(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Decorate ``func`` so that each successful call logs how long it took.

    After ``func`` returns, a message dict describing the call (built by
    ``_get_msg_dict``) plus a ``"time_spent"`` entry (elapsed nanoseconds,
    suffixed with ``"ns"``) is emitted at DEBUG level on ``_c10d_logger``.
    The return value is passed through unchanged. If ``func`` raises, the
    exception propagates and nothing is logged.
    """

    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        # perf_counter_ns is monotonic: unlike time_ns (wall clock), the
        # measured duration cannot be skewed or go negative if the system
        # clock is adjusted during the call.
        start = time.perf_counter_ns()
        func_return = func(*args, **kwargs)
        time_spent = time.perf_counter_ns() - start

        try:
            msg_dict = _get_msg_dict(func.__name__, *args, **kwargs)
        except Exception:
            # The call already succeeded; a failure while gathering
            # distributed context must not turn it into an error.
            msg_dict = {
                "func_name": f"{func.__name__}",
                "args": f"{args}, {kwargs}",
            }
        msg_dict["time_spent"] = f"{time_spent}ns"
        _c10d_logger.debug(msg_dict)

        return func_return

    return wrapper
99