from dataclasses import dataclass, field
from importlib import __import__
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from weakref import WeakSet

log = logging.getLogger(__name__)

# This is a synthetic logger which doesn't correspond to an actual logger,
# but handles all of our "tracing" logging, which is structured and doesn't go
# to stderr but always goes to a dedicated log file.  We don't put these
# loggers in the classic module hierarchy, because we don't want a suppression
# of logs to also cause a trace to get suppressed (traces typically are not
# collected, unless we are in prod, in which case they always are collected.)
#
# TODO: Maybe we should allow for some sub-hierarchy so you can control which
# traces you want to collect, for performance reasons.
#
# See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
trace_log = logging.getLogger("torch.__trace")

DEFAULT_LOG_LEVEL = logging.WARNING
LOG_ENV_VAR = "TORCH_LOGS"
LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
TRACE_ENV_VAR = "TORCH_TRACE"


@dataclass
class LogRegistry:
    # shorthand name to log qualified name
    # Note: this only contains loggers registered
    # from register_log
    # e.g. "dynamo" -> ["torch._dynamo"]
    log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict)

    # artifact logger qualified names,
    # this is populated lazily, as calls to getArtifactLogger are made
    # currently formatted as <module>.__<artifact_name>
    # e.g. "torch._dynamo.convert_frame.__guards"
    artifact_log_qnames: Set[str] = field(default_factory=set)

    # child logs of registered logs if specified via open
    # registration by the user (i.e. placing "torch._dynamo.output_graph" in the env var)
    # these need to be tracked so their levels can be reset properly
    # e.g. "torch._dynamo.output_graph"
    child_log_qnames: Set[str] = field(default_factory=set)

    # artifact names, populated by register_artifact
    artifact_names: Set[str] = field(default_factory=set)

    # Artifacts that should be visible by default in the error message
    visible_artifacts: Set[str] = field(default_factory=set)

    # A short description of each artifact
    artifact_descriptions: Dict[str, str] = field(default_factory=dict)

    # artifacts which are not displayed unless explicitly named in the
    # settings.  Ex. output_code is NOT displayed even if the inductor
    # log level is set to DEBUG.  It must be explicitly named in the settings
    off_by_default_artifact_names: Set[str] = field(default_factory=set)

    # logging format string for artifacts
    artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict)

    def is_artifact(self, name):
        return name in self.artifact_names

    def is_log(self, alias):
        return alias in self.log_alias_to_log_qnames

    # register a log with an alias
    def register_log(self, alias, log_qnames: Union[str, List[str]]):
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        self.log_alias_to_log_qnames[alias] = log_qnames

    # register an artifact name
    def register_artifact_name(
        self, name, description, visible, off_by_default, log_format
    ):
        self.artifact_names.add(name)
        if visible:
            self.visible_artifacts.add(name)
        self.artifact_descriptions[name] = description

        # if off by default, don't enable it
        # when log_name's log_level is set to DEBUG
        if off_by_default:
            self.off_by_default_artifact_names.add(name)

        if log_format is not None:
            self.artifact_log_formatters[name] = logging.Formatter(log_format)

    # register the qualified name of an artifact log
    # this is needed to know which logs need to be reset
    # whenever the log_state is changed
    def register_artifact_log(self, artifact_log_qname):
        self.artifact_log_qnames.add(artifact_log_qname)

    def register_child_log(self, log_qname):
        self.child_log_qnames.add(log_qname)

    # flattens all the qnames together (TODO: consider memoizing?)
    def get_log_qnames(self) -> Set[str]:
        return {
            qname
            for qnames in self.log_alias_to_log_qnames.values()
            for qname in qnames
        }

    def get_artifact_log_qnames(self):
        return set(self.artifact_log_qnames)

    def get_child_log_qnames(self):
        return set(self.child_log_qnames)

    def is_off_by_default(self, artifact_qname):
        return artifact_qname in self.off_by_default_artifact_names
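
# Illustrative sketch of how the registry gets populated (the real entries are
# added elsewhere via register_log/register_artifact; the calls below are only
# an example and are not executed here):
#
#   log_registry.register_log("dynamo", "torch._dynamo")
#   log_registry.register_artifact_name(
#       "guards", "print guards", visible=False, off_by_default=False, log_format=None
#   )
#   log_registry.get_log_qnames()   # -> {"torch._dynamo"}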


@dataclass
class LogState:
    # qualified log names -> currently set log level
    log_qname_to_level: Dict[str, str] = field(default_factory=dict)

    # the set of currently enabled artifacts
    artifact_names: Set[str] = field(default_factory=set)

    def enable_artifact(self, artifact_name):
        self.artifact_names.add(artifact_name)

    def is_artifact_enabled(self, name):
        return name in self.artifact_names

    def enable_log(self, log_qnames, log_level):
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        for log_qname in log_qnames:
            self.log_qname_to_level[log_qname] = log_level

    def get_log_level_pairs(self):
        """Returns all qualified module names for which the user requested
        explicit logging settings.

        .. warning::

            This function used to return all loggers, regardless of whether
            or not the user specified them; it now only returns logs which
            were explicitly mentioned by the user (and torch, which is always
            implicitly requested when we initialize our logging subsystem.)
        """
        return self.log_qname_to_level.items()

    def clear(self):
        self.log_qname_to_level.clear()
        self.artifact_names.clear()


log_registry = LogRegistry()
log_state = LogState()

# sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)
DEFAULT_LOGGING = {
    "dynamo": logging.DEBUG,
    "aot": logging.DEBUG,
    "inductor": logging.DEBUG,
    "graph_breaks": True,
    "dynamic": logging.INFO,
}


def set_logs(
    *,
    all: Optional[int] = None,
    dynamo: Optional[int] = None,
    aot: Optional[int] = None,
    autograd: Optional[int] = None,
    dynamic: Optional[int] = None,
    inductor: Optional[int] = None,
    distributed: Optional[int] = None,
    dist_c10d: Optional[int] = None,
    dist_ddp: Optional[int] = None,
    dist_fsdp: Optional[int] = None,
    onnx: Optional[int] = None,
    bytecode: bool = False,
    aot_graphs: bool = False,
    aot_joint_graph: bool = False,
    ddp_graphs: bool = False,
    graph: bool = False,
    graph_code: bool = False,
    graph_breaks: bool = False,
    graph_sizes: bool = False,
    guards: bool = False,
    recompiles: bool = False,
    recompiles_verbose: bool = False,
    trace_source: bool = False,
    trace_call: bool = False,
    output_code: bool = False,
    schedule: bool = False,
    perf_hints: bool = False,
    post_grad_graphs: bool = False,
    onnx_diagnostics: bool = False,
    fusion: bool = False,
    overlap: bool = False,
    export: Optional[int] = None,
    modules: Optional[Dict[str, Union[int, bool]]] = None,
    cudagraphs: bool = False,
    sym_node: bool = False,
):
    """
    Sets the log level for individual components and toggles individual log
    artifact types.

    .. warning:: This feature is a prototype and may have compatibility
        breaking changes in the future.

    .. note:: The ``TORCH_LOGS`` environment variable has complete precedence
        over this function, so if it was set, this function does nothing.

    A component is a set of related features in PyTorch. All of the log
    messages emitted from a given component have their own log levels. If the
    log level of a particular message has priority greater than or equal to its
    component's log level setting, it is emitted. Otherwise, it is suppressed.
    This allows you to, for instance, silence large groups of log messages that
    are not relevant to you and increase verbosity of logs for components that
    are relevant. The expected log level values, ordered from highest to lowest
    priority, are:

        * ``logging.CRITICAL``
        * ``logging.ERROR``
        * ``logging.WARNING``
        * ``logging.INFO``
        * ``logging.DEBUG``
        * ``logging.NOTSET``

    See documentation for the Python ``logging`` module for more information on
    log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_

    An artifact is a particular type of log message. Each artifact is assigned
    to a parent component. A component can emit many different kinds of
    artifacts. In general, an artifact is emitted if either its corresponding
    setting in the argument list below is turned on or if its parent component
    is set to a log level less than or equal to the log level of the artifact.

    Keyword args:
        all (:class:`Optional[int]`):
            The default log level for all components. Default: ``logging.WARN``

        dynamo (:class:`Optional[int]`):
            The log level for the TorchDynamo component. Default: ``logging.WARN``

        aot (:class:`Optional[int]`):
            The log level for the AOTAutograd component. Default: ``logging.WARN``

        autograd (:class:`Optional[int]`):
            The log level for autograd. Default: ``logging.WARN``

        inductor (:class:`Optional[int]`):
            The log level for the TorchInductor component. Default: ``logging.WARN``

        dynamic (:class:`Optional[int]`):
            The log level for dynamic shapes. Default: ``logging.WARN``

        distributed (:class:`Optional[int]`):
            Whether to log c10d communication operations and other debug info from PyTorch Distributed components.
            Default: ``logging.WARN``

        dist_c10d (:class:`Optional[int]`):
            Whether to log c10d communication operations related debug info in PyTorch Distributed components.
            Default: ``logging.WARN``

        dist_ddp (:class:`Optional[int]`):
            Whether to log debug info related to ``DistributedDataParallel`` (DDP) from PyTorch Distributed components.
            Default: ``logging.WARN``

        dist_fsdp (:class:`Optional[int]`):
            Whether to log debug info related to ``FullyShardedDataParallel`` (FSDP) in PyTorch Distributed components.
            Default: ``logging.WARN``

        onnx (:class:`Optional[int]`):
            The log level for the ONNX exporter component. Default: ``logging.WARN``

        bytecode (:class:`bool`):
            Whether to emit the original and generated bytecode from TorchDynamo.
            Default: ``False``

        aot_graphs (:class:`bool`):
            Whether to emit the graphs generated by AOTAutograd. Default: ``False``

        aot_joint_graph (:class:`bool`):
            Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``

        cudagraphs (:class:`bool`):
            Whether to emit debug info for TorchInductor cudagraphs. Default: ``False``

        ddp_graphs (:class:`bool`):
            Whether to emit graphs generated by DDPOptimizer. Default: ``False``

        graph (:class:`bool`):
            Whether to emit the graph captured by TorchDynamo in tabular format.
            Default: ``False``

        graph_code (:class:`bool`):
            Whether to emit the python source of the graph captured by TorchDynamo.
            Default: ``False``

        graph_breaks (:class:`bool`):
            Whether to emit the graph breaks encountered by TorchDynamo.
            Default: ``False``

        graph_sizes (:class:`bool`):
            Whether to emit tensor sizes of the graph captured by TorchDynamo.
            Default: ``False``

        guards (:class:`bool`):
            Whether to emit the guards generated by TorchDynamo for each compiled
            function. Default: ``False``

        recompiles (:class:`bool`):
            Whether to emit a guard failure reason and message every time
            TorchDynamo recompiles a function. Default: ``False``

        recompiles_verbose (:class:`bool`):
            Whether to emit all guard failure reasons when TorchDynamo recompiles
            a function, even those that are not actually run. Default: ``False``

        trace_source (:class:`bool`):
            Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``

        trace_call (:class:`bool`):
            Whether to emit detailed line location when TorchDynamo creates an FX node
            corresponding to a function call. Python 3.11+ only. Default: ``False``

        output_code (:class:`bool`):
            Whether to emit the TorchInductor output code. Default: ``False``

        schedule (:class:`bool`):
            Whether to emit the TorchInductor schedule. Default: ``False``

        perf_hints (:class:`bool`):
            Whether to emit the TorchInductor perf hints. Default: ``False``

        post_grad_graphs (:class:`bool`):
            Whether to emit the graphs generated by the post-grad passes. Default: ``False``

        onnx_diagnostics (:class:`bool`):
            Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``

        fusion (:class:`bool`):
            Whether to emit detailed Inductor fusion decisions. Default: ``False``

        overlap (:class:`bool`):
            Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False``

        sym_node (:class:`bool`):
            Whether to emit debug info for various SymNode operations. Default: ``False``

        export (:class:`Optional[int]`):
            The log level for export. Default: ``logging.WARN``

        modules (:class:`Optional[Dict[str, Union[int, bool]]]`):
            This argument provides an alternate way to specify the above log
            component and artifact settings, in the format of a keyword args
            dictionary given as a single argument. There are two cases
            where this is useful: (1) if a new log component or artifact has
            been registered but a keyword argument for it has not been added
            to this function, and (2) if the log level for an unregistered module
            needs to be set. This can be done by providing the fully-qualified module
            name as the key, with the log level as the value. Default: ``None``


    Example::

        >>> # xdoctest: +SKIP
        >>> import logging

        # The following changes the "dynamo" component to emit DEBUG-level
        # logs, and to emit "graph_code" artifacts.

        >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)

        # The following enables the logs for a different module

        >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG})
    """
    # ignore if env var is set
    if LOG_ENV_VAR in os.environ:
        log.warning(
            "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
        )
        return

    log_state.clear()

    modules = modules or {}

    def _set_logs(**kwargs):
        for alias, val in itertools.chain(kwargs.items(), modules.items()):  # type: ignore[union-attr]
            if val is None:
                continue

            if log_registry.is_artifact(alias):
                if not isinstance(val, bool):
                    raise ValueError(
                        f"Expected bool to enable artifact {alias}, received {val}"
                    )

                if val:
                    log_state.enable_artifact(alias)
            elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames:
                if val not in logging._levelToName:
                    raise ValueError(
                        f"Unrecognized log level for log {alias}: {val}, valid level values "
                        f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
                    )

                log_state.enable_log(
                    log_registry.log_alias_to_log_qnames.get(alias, alias), val
                )
            else:
                raise ValueError(
                    f"Unrecognized log or artifact name passed to set_logs: {alias}"
                )

        _init_logs()

    _set_logs(
        torch=all,
        dynamo=dynamo,
        aot=aot,
        autograd=autograd,
        inductor=inductor,
        dynamic=dynamic,
        bytecode=bytecode,
        aot_graphs=aot_graphs,
        aot_joint_graph=aot_joint_graph,
        ddp_graphs=ddp_graphs,
        distributed=distributed,
        dist_c10d=dist_c10d,
        dist_ddp=dist_ddp,
        dist_fsdp=dist_fsdp,
        graph=graph,
        graph_code=graph_code,
        graph_breaks=graph_breaks,
        graph_sizes=graph_sizes,
        guards=guards,
        recompiles=recompiles,
        recompiles_verbose=recompiles_verbose,
        trace_source=trace_source,
        trace_call=trace_call,
        output_code=output_code,
        schedule=schedule,
        perf_hints=perf_hints,
        post_grad_graphs=post_grad_graphs,
        onnx=onnx,
        onnx_diagnostics=onnx_diagnostics,
        fusion=fusion,
        overlap=overlap,
        export=export,
        sym_node=sym_node,
        cudagraphs=cudagraphs,
    )


def get_loggers() -> List[logging.Logger]:
    """
    Returns: a list of all registered loggers
    """
    return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()]


def register_log(setting_name, log_name):
    """
    Enables a log to be controlled by the env var and user API with the setting_name

    Args:
        setting_name: the shorthand name used in the env var and user API

        log_name: the log name that the setting_name is associated with
    """
    log_registry.register_log(setting_name, log_name)
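
# Illustrative usage (a sketch; "my_backend" is a made-up alias, not one that
# torch registers):
#
#   register_log("my_backend", "torch._my_backend")
#
# Afterwards TORCH_LOGS="+my_backend" or
# torch._logging.set_logs(modules={"my_backend": logging.DEBUG}) will control
# logging.getLogger("torch._my_backend").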


def register_artifact(
    setting_name, description, visible=False, off_by_default=False, log_format=None
):
    """
    Enables an artifact to be controlled by the env var and user API with the setting_name

    Args:
        setting_name: the shorthand name used in the env var and user API

        description: A description of what this outputs

        visible: Whether it gets suggested to users by default

        off_by_default: if True, the artifact is not emitted even when its ancestor
            loggers are enabled at level DEBUG; it must be explicitly enabled by name
    """
    log_registry.register_artifact_name(
        setting_name, description, visible, off_by_default, log_format
    )
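
# Illustrative usage (a sketch; "my_artifact" is a made-up artifact name):
#
#   register_artifact(
#       "my_artifact",
#       "Prints intermediate state from my pass",
#       off_by_default=True,
#   )
#
# With off_by_default=True the artifact stays silent even when its parent
# component is at DEBUG; it is only emitted when named explicitly, e.g.
# TORCH_LOGS="my_artifact".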


def getArtifactLogger(module_qname, artifact_name):
    if artifact_name not in log_registry.artifact_names:
        raise ValueError(
            f"Artifact name: {repr(artifact_name)} not registered, "
            f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
        )
    qname = module_qname + f".__{artifact_name}"
    log = logging.getLogger(qname)
    log.artifact_name = artifact_name  # type: ignore[attr-defined]
    log_registry.register_artifact_log(qname)
    configure_artifact_log(log)
    return log
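
# Illustrative usage (a sketch; "guards" is used because it appears above as an
# example artifact qname, e.g. torch._dynamo.convert_frame.__guards):
#
#   guards_log = getArtifactLogger(__name__, "guards")
#   guards_log.debug("GUARDS:\n%s", "...")
#
# The returned logger is named "<module>.__guards", so enabling the "guards"
# artifact (TORCH_LOGS="guards") turns these messages on without raising the
# whole module's log level.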


INCR_VERBOSITY_CHAR = "+"
DECR_VERBOSITY_CHAR = "-"
VERBOSITY_REGEX = (
    "("
    + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
    + "?)"
)


def configure_artifact_log(log):
    # If the artifact is off by default, then it should only be logged when explicitly
    # enabled; set propagate to False so that this artifact is not propagated
    # to its ancestor logger
    if log_registry.is_off_by_default(log.artifact_name):
        log.propagate = False

    # enable artifact logging when explicitly enabled
    if log_state.is_artifact_enabled(log.artifact_name):
        log.setLevel(logging.DEBUG)
        log.propagate = True


# match a comma separated list of loggable names (whitespace allowed after commas)
def _gen_settings_regex():
    return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")


def _validate_settings(settings):
    return re.fullmatch(_gen_settings_regex(), settings) is not None
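
# For reference, strings accepted by the regex above look like the TORCH_LOGS
# values documented in help_message(); illustrative, not executed here:
#
#   _validate_settings("+dynamo,guards")              # -> True
#   _validate_settings("torch._dynamo.output_graph")  # -> True
#   _validate_settings("dynamo;guards")               # -> False (';' not allowed)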


def help_message(verbose=False):
    def pad_to(s, length=30):
        assert len(s) <= length
        return s + " " * (length - len(s))

    if verbose:
        printed_artifacts = log_registry.artifact_names
    else:
        printed_artifacts = log_registry.visible_artifacts

    if verbose:
        heading = "All registered names"
    else:
        heading = "Visible registered names (use TORCH_LOGS='+help' for full list)"
    lines = (
        ["all"]
        + sorted(log_registry.log_alias_to_log_qnames.keys())
        + sorted(
            [
                f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}"
                for name in printed_artifacts
            ]
        )
    )
    setting_info = "  " + "\n  ".join(lines)
    examples = """
Examples:
  TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to
  logging.DEBUG and AOT to logging.INFO

  TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to
  logging.ERROR and TorchInductor to logging.DEBUG

  TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact

  TORCH_LOGS="+dynamo,schedule" will set the log level of TorchDynamo
  to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS="+some.random.module,schedule" will set the log level of
  some.random.module to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format
  string will set the output format
  Valid keys are "levelname", "message", "pathname", "levelno", "lineno",
  "filename" and "name".

  TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as
  well. This is useful when the output is long.
"""  # flake8: noqa: B950
    msg = f"""
TORCH_LOGS Info
{examples}

{heading}
{setting_info}
"""
    return msg


def _invalid_settings_err_msg(settings, verbose=False):
    valid_settings = ", ".join(
        ["all"]
        + list(log_registry.log_alias_to_log_qnames.keys())
        + list(log_registry.artifact_names)
    )
    msg = f"""
Invalid log settings: {settings}, must be a comma separated list of fully
qualified module names, registered log names or registered artifact names.
For more info on various settings, try TORCH_LOGS="help"
Valid settings:
{valid_settings}
"""
    return msg


def _parse_log_settings(settings):
    if settings == "help":
        raise ValueError(help_message(verbose=False))
    elif settings == "+help":
        raise ValueError(help_message(verbose=True))
    if not _validate_settings(settings):
        raise ValueError(_invalid_settings_err_msg(settings))

    settings = re.sub(r"\s+", "", settings)
    log_names = settings.split(",")

    def get_name_level_pair(name):
        clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
        clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")

        if name[0] == INCR_VERBOSITY_CHAR:
            level = logging.DEBUG
        elif name[0] == DECR_VERBOSITY_CHAR:
            level = logging.ERROR
        else:
            level = DEFAULT_LOG_LEVEL

        return clean_name, level

    log_state = LogState()

    for name in log_names:
        name, level = get_name_level_pair(name)

        if name == "all":
            name = "torch"

        if log_registry.is_log(name):
            assert level is not None
            log_qnames = log_registry.log_alias_to_log_qnames[name]
            log_state.enable_log(log_qnames, level)
        elif log_registry.is_artifact(name):
            log_state.enable_artifact(name)
        elif _is_valid_module(name):
            if not _has_registered_parent(name):
                log_registry.register_log(name, name)
            else:
                log_registry.register_child_log(name)
            log_state.enable_log(name, level)
        else:
            raise ValueError(_invalid_settings_err_msg(settings))

    return log_state
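
# Worked example (illustrative, assuming "dynamo" is a registered log alias and
# "guards" a registered artifact): TORCH_LOGS="+dynamo,guards" splits into
# ["+dynamo", "guards"]; "+dynamo" maps to ("dynamo", logging.DEBUG), so every
# qname registered under "dynamo" is set to DEBUG, while "guards" matches the
# artifact branch and is simply enabled in the returned LogState.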


def _is_valid_module(qname):
    try:
        __import__(qname)
        return True
    except ImportError:
        return False


def _update_log_state_from_env():
    global log_state
    log_setting = os.environ.get(LOG_ENV_VAR, None)
    if log_setting is not None:
        log_state = _parse_log_settings(log_setting)


def _has_registered_parent(log_qname):
    cur_log = logging.getLogger(log_qname)

    registered_log_qnames = log_registry.get_log_qnames()

    while cur_log.parent:
        if cur_log.name in registered_log_qnames:
            return True
        cur_log = cur_log.parent

    return False


# apply custom formats to artifacts when necessary
class TorchLogsFormatter(logging.Formatter):
    def __init__(self, *, trace: bool = False):
        super().__init__()
        self._is_trace = trace

    def format(self, record):
        artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None)
        if artifact_name is not None:
            artifact_formatter = log_registry.artifact_log_formatters.get(
                artifact_name, None
            )
            if artifact_formatter is not None:
                return artifact_formatter.format(record)

        record.message = record.getMessage()
        record.asctime = self.formatTime(record, "%m%d %H:%M:%S")

        # exception handling - copied from logging.Formatter.format
        s = record.message
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + record.exc_text
        if record.stack_info:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + self.formatStack(record.stack_info)

        record.rankprefix = ""
        if not self._is_trace and dist.is_available() and dist.is_initialized():
            record.rankprefix = f"[rank{dist.get_rank()}]:"

        record.traceid = ""
        if (
            not self._is_trace
            and (trace_id := torch._guards.CompileContext.current_trace_id())
            is not None
        ):
            record.traceid = f" [{trace_id}]"

        glog_level_to_abbr = {
            "DEBUG": "V",  # V is for VERBOSE in glog
            "INFO": "I",
            "WARNING": "W",
            "ERROR": "E",
            "CRITICAL": "C",
        }

        shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname)

        record.artifactprefix = ""
        if artifact_name is not None:
            record.artifactprefix = f" [__{artifact_name}]"

        prefix = (
            f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.thread} "
            f"{os.path.relpath(record.pathname, os.path.dirname(os.path.dirname(torch.__file__)))}:"
            f"{record.lineno}]{record.traceid}{record.artifactprefix}"
        )
        if self._is_trace:
            assert s == ""
            r = f"{prefix} {json.dumps(record.metadata)}"
            if record.payload is not None:
                r += "".join(f"\n\t{l}" for l in record.payload.split("\n"))
            return r
        else:
            lines = s.split("\n")
            return "\n".join(f"{prefix} {l}" for l in lines)


def _default_formatter():
    fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None)
    if fmt is None:
        return TorchLogsFormatter()
    else:
        if fmt in ("short", "basic"):
            fmt = logging.BASIC_FORMAT
        return logging.Formatter(fmt)


DEFAULT_FORMATTER = _default_formatter()


def _setup_handlers(create_handler_fn, log):
    debug_handler = _track_handler(create_handler_fn())
    debug_handler.setFormatter(DEFAULT_FORMATTER)
    debug_handler.setLevel(logging.DEBUG)
    log.addHandler(debug_handler)


handlers = WeakSet()  # type: ignore[var-annotated]


# mark handlers that we've created
# so we don't modify user handlers
def _track_handler(handler):
    handlers.add(handler)
    return handler


def _is_torch_handler(handler):
    return handler in handlers


# clears all torch handlers on specified loggers
def _clear_handlers(log):
    to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
    for handler in to_remove:
        log.removeHandler(handler)


def _reset_logs():
    # reset all registered logs
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        log.setLevel(logging.WARNING)
        log.propagate = False
        _clear_handlers(log)

    # reset all artifact and child logs
    for artifact_log_qname in itertools.chain(
        log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
    ):
        log = logging.getLogger(artifact_log_qname)
        log.setLevel(logging.NOTSET)
        log.propagate = True

    trace_log.propagate = False
    _clear_handlers(trace_log)


def _get_log_state():
    return log_state


def _set_log_state(state):
    global log_state
    log_state = state


def _init_logs(log_file_name=None):
    _reset_logs()
    _update_log_state_from_env()

    out = os.environ.get(LOG_OUT_ENV_VAR, None)
    if out is not None:
        log_file_name = out

    # First, reset all known (registered) loggers to NOTSET, so that they
    # respect their parent log level
    for log_qname in log_registry.get_log_qnames():
        # But not the top level torch level: this defaults to WARNING so
        # that our log messages don't leak to the lower levels
        if log_qname == "torch":
            continue
        log = logging.getLogger(log_qname)
        log.setLevel(logging.NOTSET)

    # Now, for all loggers which the user requested to have non-standard
    # logging behavior, modify their log levels
    for log_qname, level in log_state.get_log_level_pairs():
        log = logging.getLogger(log_qname)
        log.setLevel(level)

    # Finally, setup handlers for all registered loggers
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        _setup_handlers(
            logging.StreamHandler,
            log,
        )

        if log_file_name is not None:
            _setup_handlers(
                lambda: logging.FileHandler(log_file_name),
                log,
            )

    # configure artifact loggers, note: this must happen last
    # since the levels of ancestor loggers are taken into account
    for artifact_log_qname in log_registry.get_artifact_log_qnames():
        log = logging.getLogger(artifact_log_qname)
        configure_artifact_log(log)

    # Setup handler for the special trace_log, with different default
    # behavior for trace logging
    #
    # TODO: Automatically initialize this in Tupperware environment to point
    # to /logs/dedicated_logs_XXX
    trace_file_name = os.environ.get(TRACE_ENV_VAR, None)
    handler: Optional[logging.Handler] = None
    if trace_file_name is not None:
        handler = logging.FileHandler(trace_file_name)
    else:
        # This handler may remove itself if we are not actually in an FB
        # environment.  This allows us to defer actually initializing it until
        # we actually need to log anything.  This is important because JK
        # initializes a C++ singleton, which will pork our process if we
        # initialize it too early.
        handler = LazyFbTraceHandler()
    # This log is ALWAYS at debug level.  We will additionally test if there
    # are any handlers before deciding to actually call logging on this; do
    # not gate on the level manually.
    trace_log.setLevel(logging.DEBUG)
    trace_log_handler = _track_handler(handler)
    trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
    trace_log.addHandler(trace_log_handler)


class LazyFbTraceHandler(logging.StreamHandler):
    """Like FileHandler, but the file is allocated lazily only upon the first log message"""

    def __init__(self):
        # This is implemented in the same way that delay is implemented on
        # FileHandler
        logging.Handler.__init__(self)
        self.stream = None
        self._builtin_open = open

    # cloned from FileHandler in cpython
    def close(self):
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        stream = self.stream
                        self.stream = None
                        if hasattr(stream, "close"):
                            stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                # Also see Issue #42378: we also rely on
                # self._closed being set to True there
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def emit(self, record):
        if self.stream is None:
            # TODO: more robust is_fbcode test
            TRACE_LOG_DIR = "/logs"
            open_func = self._builtin_open

            import torch.version as torch_version

            if hasattr(torch_version, "git_version"):
                log.info("LazyFbTraceHandler: disabled because not fbcode")
            elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
                log.info(
                    "LazyFbTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                )
            elif not os.path.exists(TRACE_LOG_DIR):
                log.info(
                    "LazyFbTraceHandler: disabled because %s does not exist",
                    TRACE_LOG_DIR,
                )
            elif not os.access(TRACE_LOG_DIR, os.W_OK):
                log.info(
                    "LazyFbTraceHandler: disabled because %s is not writeable",
                    TRACE_LOG_DIR,
                )
            else:
                ranksuffix = ""
                if dist.is_available() and dist.is_initialized():
                    ranksuffix = f"rank_{dist.get_rank()}_"
                self.stream = tempfile.NamedTemporaryFile(
                    mode="w+",
                    suffix=".log",
                    prefix=f"dedicated_log_torch_trace_{ranksuffix}",
                    dir=TRACE_LOG_DIR,
                    delete=False,
                )
                log.info("LazyFbTraceHandler: logging to %s", self.stream.name)

            if self.stream is None:
                # We go poof, remove and no-op
                trace_log.removeHandler(self)
                return
        if self.stream:
            super().emit(record)


@functools.lru_cache(None)
def warning_once(logger_obj, *args, **kwargs):
    """
    This function is similar to `logger.warning()`, but will emit the warning with the same message only once.
    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
    The assumption here is that all warning messages are unique across the code. If they aren't, then we need to switch to
    another type of cache that includes the caller frame information in the hashing function.
    """
    logger_obj.warning(*args, **kwargs)
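
# Illustrative usage (a sketch): because the lru_cache key is the logger plus
# the message arguments, repeated calls with identical arguments emit once.
#
#   for _ in range(100):
#       warning_once(log, "falling back to eager for this op")  # logged a single time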


class LazyString:
    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __str__(self):
        return self.func(*self.args, **self.kwargs)


def trace_structured(
    name: str,
    # NB: metadata expected to be dict so adding more info is forward compatible
    # Tuple[str, int] is a special case for string interning
    metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
    *,
    payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
    suppress_context: bool = False,
):
    """
    metadata is an arbitrary JSON compatible struct, but it's expected to not be
    too long (e.g., less than 1MB)

    payload is an arbitrary string, which can be arbitrarily long (but expected to have
    newlines so no lines are too long)
    """
    assert name not in ["rank", "frame_id", "frame_compile_id", "attempt"]
    assert callable(
        metadata_fn
    ), f"metadata_fn should be callable, but got {type(metadata_fn)}"
    assert callable(
        payload_fn
    ), f"payload_fn should be callable, but got {type(payload_fn)}"
    # trace_log never propagates and is ALWAYS DEBUG, so also check that there
    # are handlers instead of checking the log level
    if trace_log.handlers:
        record: Dict[str, object] = {}
        record[name] = metadata_fn()
        if not suppress_context:
            # TODO: Actually, the rank probably should just be emitted once at
            # the top, and not repeatedly spammed in all the logs, since it
            # never changes and we assume no interleaving
            if dist.is_available() and dist.is_initialized():
                record["rank"] = dist.get_rank()
            if (
                trace_id := torch._guards.CompileContext.current_trace_id()
            ) is not None:
                record["frame_id"] = trace_id.compile_id.frame_id
                record["frame_compile_id"] = trace_id.compile_id.frame_compile_id
                record["attempt"] = trace_id.attempt
        payload = payload_fn()
        if payload is not None:
            if not isinstance(payload, str):
                if isinstance(payload, list):
                    # special case to look better
                    payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]"
                else:
                    # force newlines so we are unlikely to overflow line limit
                    payload = json.dumps(payload, indent=0)
            h = hashlib.md5()
            h.update(payload.encode("utf-8"))
            record["has_payload"] = h.hexdigest()
        trace_log.debug(
            "", extra={"metadata": record, "payload": payload}, stacklevel=2
        )


# Imported at the end of the file to avoid import cycles at module load time.
import torch._guards
import torch._utils_internal
import torch.distributed as dist