import functools
import hashlib
import itertools
import json
import logging
import os
import os.path
import re
import tempfile
from dataclasses import dataclass, field
from importlib import __import__
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from weakref import WeakSet

log = logging.getLogger(__name__)

# This is a synthetic logger which doesn't correspond to an actual logger,
# but handles all of our "tracing" logging, which is structured and doesn't go
# to stderr but always goes to a dedicated log file.  We don't put these
# loggers in the classic module hierarchy, because we don't want a suppression
# of logs to also cause a trace to get suppressed (traces typically are not
# collected, unless we are in prod, in which case they always are collected.)
#
# TODO: Maybe we should allow for some sub-hierarchy so you can control which
# traces you want to collect, for performance reasons.
#
# See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
trace_log = logging.getLogger("torch.__trace")

DEFAULT_LOG_LEVEL = logging.WARNING
LOG_ENV_VAR = "TORCH_LOGS"
LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
TRACE_ENV_VAR = "TORCH_TRACE"
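
# Example (illustrative) of driving these knobs from the shell; the settings
# string is just one of the combinations documented in help_message() below:
#   TORCH_LOGS="+dynamo,guards" TORCH_LOGS_OUT=/tmp/output.txt python script.py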


@dataclass
class LogRegistry:
    # shorthand name to log qualified name
    # Note: this only contains loggers registered
    # from register_log
    # e.g. "dynamo" -> "torch._dynamo"
    log_alias_to_log_qnames: Dict[str, List[str]] = field(default_factory=dict)

    # artifact logger qualified names,
    # this is populated lazily, as calls to getArtifactLogger are made;
    # currently formatted as <module>.__<artifact_name>
    # e.g. "torch._dynamo.convert_frame.__guards"
    artifact_log_qnames: Set[str] = field(default_factory=set)

    # child logs of registered logs if specified via open
    # registration by the user (i.e. placing "torch._dynamo.output_graph" in the env var)
    # these need to be tracked so their levels can be reset properly
    # e.g. "torch._dynamo.output_graph"
    child_log_qnames: Set[str] = field(default_factory=set)

    # artifact names, populated by register_artifact
    # e.g. "guards"
    artifact_names: Set[str] = field(default_factory=set)

    # Artifacts that should be visible by default in the error message
    visible_artifacts: Set[str] = field(default_factory=set)

    # A short description of each artifact
    artifact_descriptions: Dict[str, str] = field(default_factory=dict)

    # artifacts which are not displayed unless explicitly named in the
    # settings. E.g. output_code is NOT displayed even if the inductor
    # log level is set to DEBUG; it must be explicitly named in the settings
    off_by_default_artifact_names: Set[str] = field(default_factory=set)

    # logging format string for artifacts
    artifact_log_formatters: Dict[str, logging.Formatter] = field(default_factory=dict)

    def is_artifact(self, name):
        return name in self.artifact_names

    def is_log(self, alias):
        return alias in self.log_alias_to_log_qnames

    # register a log with an alias
    def register_log(self, alias, log_qnames: Union[str, List[str]]):
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        self.log_alias_to_log_qnames[alias] = log_qnames

    # register an artifact name
    def register_artifact_name(
        self, name, description, visible, off_by_default, log_format
    ):
        self.artifact_names.add(name)
        if visible:
            self.visible_artifacts.add(name)
        self.artifact_descriptions[name] = description

        # if off by default, don't enable it
        # when log_name's log_level is set to DEBUG
        if off_by_default:
            self.off_by_default_artifact_names.add(name)

        if log_format is not None:
            self.artifact_log_formatters[name] = logging.Formatter(log_format)

    # register the qualified name of an artifact log
    # this is needed to know which logs need to be reset
    # whenever the log_state is changed
    def register_artifact_log(self, artifact_log_qname):
        self.artifact_log_qnames.add(artifact_log_qname)

    def register_child_log(self, log_qname):
        self.child_log_qnames.add(log_qname)

    # flattens all the qnames together (TODO: consider memoizing?)
    def get_log_qnames(self) -> Set[str]:
        return {
            qname
            for qnames in self.log_alias_to_log_qnames.values()
            for qname in qnames
        }

    def get_artifact_log_qnames(self):
        return set(self.artifact_log_qnames)

    def get_child_log_qnames(self):
        return set(self.child_log_qnames)

    def is_off_by_default(self, artifact_qname):
        return artifact_qname in self.off_by_default_artifact_names


@dataclass
class LogState:
    # qualified log names -> currently set log level
    log_qname_to_level: Dict[str, str] = field(default_factory=dict)

    # the set of currently enabled artifacts
    artifact_names: Set[str] = field(default_factory=set)

    def enable_artifact(self, artifact_name):
        self.artifact_names.add(artifact_name)

    def is_artifact_enabled(self, name):
        return name in self.artifact_names

    def enable_log(self, log_qnames, log_level):
        if isinstance(log_qnames, str):
            log_qnames = [log_qnames]
        for log_qname in log_qnames:
            self.log_qname_to_level[log_qname] = log_level

    def get_log_level_pairs(self):
        """Returns all qualified module names for which the user requested
        explicit logging settings.

        .. warning::

            This function used to return all loggers, regardless of whether
            the user specified them; it now only returns logs which were
            explicitly mentioned by the user (and torch, which is always
            implicitly requested when we initialize our logging subsystem.)
        """
        return self.log_qname_to_level.items()

    def clear(self):
        self.log_qname_to_level.clear()
        self.artifact_names.clear()


log_registry = LogRegistry()
log_state = LogState()

# sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)
DEFAULT_LOGGING = {
    "dynamo": logging.DEBUG,
    "aot": logging.DEBUG,
    "inductor": logging.DEBUG,
    "ddp_graphs": True,
    "graph_breaks": True,
    "guards": True,
    "recompiles": True,
    "dynamic": logging.INFO,
}


def set_logs(
    *,
    all: Optional[int] = None,
    dynamo: Optional[int] = None,
    aot: Optional[int] = None,
    autograd: Optional[int] = None,
    dynamic: Optional[int] = None,
    inductor: Optional[int] = None,
    distributed: Optional[int] = None,
    dist_c10d: Optional[int] = None,
    dist_ddp: Optional[int] = None,
    dist_fsdp: Optional[int] = None,
    onnx: Optional[int] = None,
    bytecode: bool = False,
    aot_graphs: bool = False,
    aot_joint_graph: bool = False,
    ddp_graphs: bool = False,
    graph: bool = False,
    graph_code: bool = False,
    graph_breaks: bool = False,
    graph_sizes: bool = False,
    guards: bool = False,
    recompiles: bool = False,
    recompiles_verbose: bool = False,
    trace_source: bool = False,
    trace_call: bool = False,
    output_code: bool = False,
    schedule: bool = False,
    perf_hints: bool = False,
    post_grad_graphs: bool = False,
    onnx_diagnostics: bool = False,
    fusion: bool = False,
    overlap: bool = False,
    export: Optional[int] = None,
    modules: Optional[Dict[str, Union[int, bool]]] = None,
    cudagraphs: bool = False,
    sym_node: bool = False,
):
    """
    Sets the log level for individual components and toggles individual log
    artifact types.

    .. warning:: This feature is a prototype and may have compatibility
        breaking changes in the future.

    .. note:: The ``TORCH_LOGS`` environment variable has complete precedence
        over this function, so if it is set, this function does nothing.

    A component is a set of related features in PyTorch. All of the log
    messages emitted from a given component have their own log levels. If the
    log level of a particular message has priority greater than or equal to its
    component's log level setting, it is emitted. Otherwise, it is suppressed.
    This allows you to, for instance, silence large groups of log messages that
    are not relevant to you and increase verbosity of logs for components that
    are relevant. The expected log level values, ordered from highest to lowest
    priority, are:

        * ``logging.CRITICAL``
        * ``logging.ERROR``
        * ``logging.WARNING``
        * ``logging.INFO``
        * ``logging.DEBUG``
        * ``logging.NOTSET``

    See documentation for the Python ``logging`` module for more information on
    log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_

    An artifact is a particular type of log message. Each artifact is assigned
    to a parent component. A component can emit many different kinds of
    artifacts. In general, an artifact is emitted if either its corresponding
    setting in the argument list below is turned on or if its parent component
    is set to a log level less than or equal to the log level of the artifact.

    Keyword args:
        all (:class:`Optional[int]`):
            The default log level for all components. Default: ``logging.WARN``

        dynamo (:class:`Optional[int]`):
            The log level for the TorchDynamo component. Default: ``logging.WARN``

        aot (:class:`Optional[int]`):
            The log level for the AOTAutograd component. Default: ``logging.WARN``

        autograd (:class:`Optional[int]`):
            The log level for autograd. Default: ``logging.WARN``

        inductor (:class:`Optional[int]`):
            The log level for the TorchInductor component. Default: ``logging.WARN``

        dynamic (:class:`Optional[int]`):
            The log level for dynamic shapes. Default: ``logging.WARN``

        distributed (:class:`Optional[int]`):
            Whether to log c10d communication operations and other debug info from PyTorch Distributed components.
            Default: ``logging.WARN``

        dist_c10d (:class:`Optional[int]`):
            Whether to log c10d communication operations related debug info in PyTorch Distributed components.
            Default: ``logging.WARN``

        dist_ddp (:class:`Optional[int]`):
            Whether to log debug info related to ``DistributedDataParallel`` (DDP) from PyTorch Distributed components.
            Default: ``logging.WARN``

        dist_fsdp (:class:`Optional[int]`):
            Whether to log debug info related to ``FullyShardedDataParallel`` (FSDP) in PyTorch Distributed components.
            Default: ``logging.WARN``

        onnx (:class:`Optional[int]`):
            The log level for the ONNX exporter component. Default: ``logging.WARN``

        bytecode (:class:`bool`):
            Whether to emit the original and generated bytecode from TorchDynamo.
            Default: ``False``

        aot_graphs (:class:`bool`):
            Whether to emit the graphs generated by AOTAutograd. Default: ``False``

        aot_joint_graph (:class:`bool`):
            Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``

        cudagraphs (:class:`bool`):
            Whether to log information from inductor cudagraphs. Default: ``False``

        ddp_graphs (:class:`bool`):
            Whether to emit graphs generated by DDPOptimizer. Default: ``False``

        graph (:class:`bool`):
            Whether to emit the graph captured by TorchDynamo in tabular format.
            Default: ``False``

        graph_code (:class:`bool`):
            Whether to emit the python source of the graph captured by TorchDynamo.
            Default: ``False``

        graph_breaks (:class:`bool`):
            Whether to emit the graph breaks encountered by TorchDynamo.
            Default: ``False``

        graph_sizes (:class:`bool`):
            Whether to emit tensor sizes of the graph captured by TorchDynamo.
            Default: ``False``

        guards (:class:`bool`):
            Whether to emit the guards generated by TorchDynamo for each compiled
            function. Default: ``False``

        recompiles (:class:`bool`):
            Whether to emit a guard failure reason and message every time
            TorchDynamo recompiles a function. Default: ``False``

        recompiles_verbose (:class:`bool`):
            Whether to emit all guard failure reasons when TorchDynamo recompiles
            a function, even those that are not actually run. Default: ``False``

        trace_source (:class:`bool`):
            Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``

        trace_call (:class:`bool`):
            Whether to emit detailed line location when TorchDynamo creates an FX node
            corresponding to a function call. Python 3.11+ only. Default: ``False``

        output_code (:class:`bool`):
            Whether to emit the TorchInductor output code. Default: ``False``

        schedule (:class:`bool`):
            Whether to emit the TorchInductor schedule. Default: ``False``

        perf_hints (:class:`bool`):
            Whether to emit the TorchInductor perf hints. Default: ``False``

        post_grad_graphs (:class:`bool`):
            Whether to emit the graphs generated by post grad passes. Default: ``False``

        onnx_diagnostics (:class:`bool`):
            Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``

        fusion (:class:`bool`):
            Whether to emit detailed Inductor fusion decisions. Default: ``False``

        overlap (:class:`bool`):
            Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False``

        sym_node (:class:`bool`):
            Whether to emit debug info for various SymNode operations. Default: ``False``

        export (:class:`Optional[int]`):
            The log level for export. Default: ``logging.WARN``

        modules (dict):
            This argument provides an alternate way to specify the above log
            component and artifact settings, in the format of a keyword args
            dictionary given as a single argument. There are two cases
            where this is useful: (1) if a new log component or artifact has
            been registered but a keyword argument for it has not been added
            to this function and (2) if the log level for an unregistered module
            needs to be set. This can be done by providing the fully-qualified module
            name as the key, with the log level as the value. Default: ``None``


    Example::

        >>> # xdoctest: +SKIP
        >>> import logging

        # The following changes the "dynamo" component to emit DEBUG-level
        # logs, and to emit "graph_code" artifacts.

        >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)

        # The following enables the logs for a different module

        >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG})
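
        # The following applies the DEFAULT_LOGGING preset defined in this
        # module (see the "sample usage" comment on DEFAULT_LOGGING above)

        >>> torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)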
    """
    # ignore if env var is set
    if LOG_ENV_VAR in os.environ:
        log.warning(
            "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
        )
        return

    log_state.clear()

    modules = modules or {}

    def _set_logs(**kwargs):
        for alias, val in itertools.chain(kwargs.items(), modules.items()):  # type: ignore[union-attr]
            if val is None:
                continue

            if log_registry.is_artifact(alias):
                if not isinstance(val, bool):
                    raise ValueError(
                        f"Expected bool to enable artifact {alias}, received {val}"
                    )

                if val:
                    log_state.enable_artifact(alias)
            elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames:
                if val not in logging._levelToName:
                    raise ValueError(
                        f"Unrecognized log level for log {alias}: {val}, valid level values "
                        f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
                    )

                log_state.enable_log(
                    log_registry.log_alias_to_log_qnames.get(alias, alias), val
                )
            else:
                raise ValueError(
                    f"Unrecognized log or artifact name passed to set_logs: {alias}"
                )

        _init_logs()

    _set_logs(
        torch=all,
        dynamo=dynamo,
        aot=aot,
        autograd=autograd,
        inductor=inductor,
        dynamic=dynamic,
        bytecode=bytecode,
        aot_graphs=aot_graphs,
        aot_joint_graph=aot_joint_graph,
        ddp_graphs=ddp_graphs,
        distributed=distributed,
        dist_c10d=dist_c10d,
        dist_ddp=dist_ddp,
        dist_fsdp=dist_fsdp,
        graph=graph,
        graph_code=graph_code,
        graph_breaks=graph_breaks,
        graph_sizes=graph_sizes,
        guards=guards,
        recompiles=recompiles,
        recompiles_verbose=recompiles_verbose,
        trace_source=trace_source,
        trace_call=trace_call,
        output_code=output_code,
        schedule=schedule,
        perf_hints=perf_hints,
        post_grad_graphs=post_grad_graphs,
        onnx=onnx,
        onnx_diagnostics=onnx_diagnostics,
        fusion=fusion,
        overlap=overlap,
        sym_node=sym_node,
        export=export,
        cudagraphs=cudagraphs,
    )


def get_loggers():
    """
    Returns: a list of all registered loggers
    """
    return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()]


def register_log(setting_name, log_name):
    """
    Enables a log to be controlled by the env var and user API with the setting_name
    Args:
        setting_name:  the shorthand name used in the env var and user API
        log_name:  the log name that the setting_name is associated with
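
    Example (illustrative; this pairing mirrors the "dynamo" -> "torch._dynamo"
    alias mentioned in LogRegistry above)::

        >>> # xdoctest: +SKIP
        >>> register_log("dynamo", "torch._dynamo")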
    """
    log_registry.register_log(setting_name, log_name)


def register_artifact(
    setting_name, description, visible=False, off_by_default=False, log_format=None
):
    """
    Enables an artifact to be controlled by the env var and user API with the given setting_name
    Args:
        setting_name: the shorthand name used in the env var and user API
        description: A description of what this outputs
        visible: Whether it gets suggested to users by default
        off_by_default: whether this artifact should stay suppressed (rather than logged) when its
            ancestor loggers are enabled at level DEBUG; it is then only emitted when named explicitly
        log_format: an optional format string applied only to this artifact's log messages
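
    Example (illustrative; "guards" matches the artifact name used elsewhere in
    this file, the description text here is made up)::

        >>> # xdoctest: +SKIP
        >>> register_artifact(
        ...     "guards",
        ...     "Print the guards generated for each compiled function",
        ...     visible=True,
        ... )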
    """
    log_registry.register_artifact_name(
        setting_name, description, visible, off_by_default, log_format
    )


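# Returns a logger for a registered artifact, named <module_qname>.__<artifact_name>.
# Illustrative use from a module that wants to emit, say, the "graph_breaks"
# artifact (the artifact name must already be registered, e.g. in
# torch._logging.registrations):
#
#   graph_breaks_log = getArtifactLogger(__name__, "graph_breaks")
#   graph_breaks_log.debug("...")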
def getArtifactLogger(module_qname, artifact_name):
    if artifact_name not in log_registry.artifact_names:
        raise ValueError(
            f"Artifact name: {repr(artifact_name)} not registered, "
            f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
        )
    qname = module_qname + f".__{artifact_name}"
    log = logging.getLogger(qname)
    log.artifact_name = artifact_name  # type: ignore[attr-defined]
    log_registry.register_artifact_log(qname)
    configure_artifact_log(log)
    return log


INCR_VERBOSITY_CHAR = "+"
DECR_VERBOSITY_CHAR = "-"
VERBOSITY_REGEX = (
    "("
    + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
    + "?)"
)


def configure_artifact_log(log):
    # If the artifact is off by default, then it should only be logged when explicitly
    # enabled; set propagate to False so that this artifact is not propagated
    # to its ancestor logger
    if log_registry.is_off_by_default(log.artifact_name):
        log.propagate = False

    # enable artifact logging when explicitly enabled
    if log_state.is_artifact_enabled(log.artifact_name):
        log.setLevel(logging.DEBUG)
        log.propagate = True


# match a comma separated list of loggable names (whitespace allowed after commas)
def _gen_settings_regex():
    return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")


def _validate_settings(settings):
    return re.fullmatch(_gen_settings_regex(), settings) is not None


def help_message(verbose=False):
    def pad_to(s, length=30):
        assert len(s) <= length
        return s + " " * (length - len(s))

    if verbose:
        printed_artifacts = log_registry.artifact_names
    else:
        printed_artifacts = log_registry.visible_artifacts

    if verbose:
        heading = "All registered names"
    else:
        heading = "Visible registered names (use TORCH_LOGS='+help' for full list)"
    lines = (
        ["all"]
        + sorted(log_registry.log_alias_to_log_qnames.keys())
        + sorted(
            [
                f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}"
                for name in printed_artifacts
            ]
        )
    )
    setting_info = "  " + "\n  ".join(lines)
    examples = """
Examples:
  TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to
  logging.DEBUG and AOT to logging.INFO

  TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to
  logging.ERROR and TorchInductor to logging.DEBUG

  TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact

  TORCH_LOGS="+dynamo,schedule" will set the log level of TorchDynamo
  to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS="+some.random.module,schedule" will set the log level of
  some.random.module to logging.DEBUG and enable the schedule artifact

  TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format
  string will set the output format
  Valid keys are "levelname", "message", "pathname", "levelno", "lineno",
  "filename" and "name".

  TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as
  well. This is useful when the output is long.
"""  # flake8: noqa: B950
    msg = f"""
TORCH_LOGS Info
{examples}

{heading}
{setting_info}
"""
    return msg


def _invalid_settings_err_msg(settings, verbose=False):
    valid_settings = ", ".join(
        ["all"]
        + list(log_registry.log_alias_to_log_qnames.keys())
        + list(log_registry.artifact_names)
    )
    msg = f"""
Invalid log settings: {settings}, must be a comma separated list of fully
qualified module names, registered log names or registered artifact names.
For more info on various settings, try TORCH_LOGS="help"
Valid settings:
{valid_settings}
"""
    return msg


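# Parse a TORCH_LOGS-style settings string into a LogState.  For example
# (illustrative), "+dynamo,guards" maps every logger registered under the
# "dynamo" alias to logging.DEBUG and enables the "guards" artifact.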
@functools.lru_cache
def _parse_log_settings(settings):
    if settings == "":
        return dict()

    if settings == "help":
        raise ValueError(help_message(verbose=False))
    elif settings == "+help":
        raise ValueError(help_message(verbose=True))
    if not _validate_settings(settings):
        raise ValueError(_invalid_settings_err_msg(settings))

    settings = re.sub(r"\s+", "", settings)
    log_names = settings.split(",")

    def get_name_level_pair(name):
        clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
        clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")

        if name[0] == INCR_VERBOSITY_CHAR:
            level = logging.DEBUG
        elif name[0] == DECR_VERBOSITY_CHAR:
            level = logging.ERROR
        else:
            level = logging.INFO

        return clean_name, level

    log_state = LogState()

    for name in log_names:
        name, level = get_name_level_pair(name)

        if name == "all":
            name = "torch"

        if log_registry.is_log(name):
            assert level is not None
            log_qnames = log_registry.log_alias_to_log_qnames[name]
            log_state.enable_log(log_qnames, level)
        elif log_registry.is_artifact(name):
            log_state.enable_artifact(name)
        elif _is_valid_module(name):
            if not _has_registered_parent(name):
                log_registry.register_log(name, name)
            else:
                log_registry.register_child_log(name)
            log_state.enable_log(name, level)
        else:
            raise ValueError(_invalid_settings_err_msg(settings))

    return log_state


def _is_valid_module(qname):
    try:
        __import__(qname)
        return True
    except ImportError:
        return False


def _update_log_state_from_env():
    global log_state
    log_setting = os.environ.get(LOG_ENV_VAR, None)
    if log_setting is not None:
        log_state = _parse_log_settings(log_setting)


def _has_registered_parent(log_qname):
    cur_log = logging.getLogger(log_qname)

    registered_log_qnames = log_registry.get_log_qnames()

    while cur_log.parent:
        if cur_log.name in registered_log_qnames:
            return True
        cur_log = cur_log.parent

    return False


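# Formats records with a glog-style one-line prefix.  An emitted line looks
# roughly like the following (illustrative values):
#
#   [rank0]:V0611 14:52:03.123456 139865 torch/_dynamo/convert_frame.py:123] [0/0] message
#
# i.e. rank prefix (only when distributed is initialized), level abbreviation,
# date, time, thread id, path:lineno, trace id, optional artifact suffix, message.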
# apply custom formats to artifacts when necessary
class TorchLogsFormatter(logging.Formatter):
    def __init__(self, *, trace: bool = False):
        super().__init__()
        self._is_trace = trace

    def format(self, record):
        artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None)
        if artifact_name is not None:
            artifact_formatter = log_registry.artifact_log_formatters.get(
                artifact_name, None
            )
            if artifact_formatter is not None:
                return artifact_formatter.format(record)

        record.message = record.getMessage()
        record.asctime = self.formatTime(record, "%m%d %H:%M:%S")

        # exception handling - copied from logging.Formatter.format
        s = record.message
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + record.exc_text
        if record.stack_info:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + self.formatStack(record.stack_info)

        record.rankprefix = ""
        if not self._is_trace and dist.is_available() and dist.is_initialized():
            record.rankprefix = f"[rank{dist.get_rank()}]:"

        record.traceid = ""
        if (
            not self._is_trace
            and (trace_id := torch._guards.CompileContext.current_trace_id())
            is not None
        ):
            record.traceid = f" [{trace_id}]"

        glog_level_to_abbr = {
            "DEBUG": "V",  # V is for VERBOSE in glog
            "INFO": "I",
            "WARNING": "W",
            "ERROR": "E",
            "CRITICAL": "C",
        }

        shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname)

        record.artifactprefix = ""
        if artifact_name is not None:
            record.artifactprefix = f" [__{artifact_name}]"

        prefix = (
            f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs*1000):06d} {record.thread} "
            f"{os.path.relpath(record.pathname, os.path.dirname(os.path.dirname(torch.__file__)))}:"
            f"{record.lineno}]{record.traceid}{record.artifactprefix}"
        )
        if self._is_trace:
            assert s == ""
            r = f"{prefix} {json.dumps(record.metadata)}"
            if record.payload is not None:
                r += "".join(f"\n\t{l}" for l in record.payload.split("\n"))
            return r
        else:
            lines = s.split("\n")
            return "\n".join(f"{prefix} {l}" for l in lines)


def _default_formatter():
    fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None)
    if fmt is None:
        return TorchLogsFormatter()
    else:
        if fmt in ("short", "basic"):
            fmt = logging.BASIC_FORMAT
        return logging.Formatter(fmt)


DEFAULT_FORMATTER = _default_formatter()


def _setup_handlers(create_handler_fn, log):
    debug_handler = _track_handler(create_handler_fn())
    debug_handler.setFormatter(DEFAULT_FORMATTER)
    debug_handler.setLevel(logging.DEBUG)
    log.addHandler(debug_handler)


handlers = WeakSet()  # type: ignore[var-annotated]


# mark handlers that we've created
# so we don't modify user handlers
def _track_handler(handler):
    handlers.add(handler)
    return handler


def _is_torch_handler(handler):
    return handler in handlers


# clears all torch handlers on specified loggers
def _clear_handlers(log):
    to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
    for handler in to_remove:
        log.removeHandler(handler)


def _reset_logs():
    # reset all registered logs
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        log.setLevel(logging.WARNING)
        log.propagate = False
        _clear_handlers(log)

    # reset all artifact and child logs
    for artifact_log_qname in itertools.chain(
        log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
    ):
        log = logging.getLogger(artifact_log_qname)
        log.setLevel(logging.NOTSET)
        log.propagate = True

    trace_log.propagate = False
    _clear_handlers(trace_log)


def _get_log_state():
    return log_state


def _set_log_state(state):
    global log_state
    log_state = state


def _init_logs(log_file_name=None):
    _reset_logs()
    _update_log_state_from_env()

    out = os.environ.get(LOG_OUT_ENV_VAR, None)
    if out is not None:
        log_file_name = out

    # First, reset all known (registered) loggers to NOTSET, so that they
    # respect their parent log level
    for log_qname in log_registry.get_log_qnames():
        # But not the top level torch level: this defaults to WARNING so
        # that our log messages don't leak to the lower levels
        if log_qname == "torch":
            continue
        log = logging.getLogger(log_qname)
        log.setLevel(logging.NOTSET)

    # Now, for all loggers which the user requested to have non-standard
    # logging behavior, modify their log levels
    for log_qname, level in log_state.get_log_level_pairs():
        log = logging.getLogger(log_qname)
        log.setLevel(level)

    # Finally, setup handlers for all registered loggers
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        _setup_handlers(
            logging.StreamHandler,
            log,
        )

        if log_file_name is not None:
            _setup_handlers(
                lambda: logging.FileHandler(log_file_name),
                log,
            )

    # configure artifact loggers, note: this must happen last
    # since the levels of ancestor loggers are taken into account
    for artifact_log_qname in log_registry.get_artifact_log_qnames():
        log = logging.getLogger(artifact_log_qname)
        configure_artifact_log(log)

    # Setup handler for the special trace_log, with different default
    # configuration
    #
    # TODO: Automatically initialize this in Tupperware environment to point
    # to /logs/dedicated_logs_XXX
    trace_file_name = os.environ.get(TRACE_ENV_VAR, None)
    handler: Optional[logging.Handler] = None
    if trace_file_name is not None:
        handler = logging.FileHandler(trace_file_name)
    else:
        # This handler may remove itself if we are not actually in an FB
        # environment.  This allows us to defer actually initializing it until
        # we actually need to log anything.  This is important because JK
        # initializes a C++ singleton, which will pork our process if we
        # subsequently fork.
        handler = LazyFbTraceHandler()
    # This log is ALWAYS at debug level.  We will additionally test if there
    # are any handlers before deciding to actually call logging on this.  Do
    # not manually call
    trace_log.setLevel(logging.DEBUG)
    trace_log_handler = _track_handler(handler)
    trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
    trace_log.addHandler(trace_log_handler)


class LazyFbTraceHandler(logging.StreamHandler):
    """Like FileHandler, but the file is allocated lazily only upon the first log message"""

    def __init__(self):
        # This is implemented in the same way that delay is implemented on
        # FileHandler
        logging.Handler.__init__(self)
        self.stream = None
        self._builtin_open = open

    # cloned from FileHandler in cpython
    def close(self):
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        stream = self.stream
                        self.stream = None
                        if hasattr(stream, "close"):
                            stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                # Also see Issue #42378: we also rely on
                # self._closed being set to True there
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def emit(self, record):
        if self.stream is None:
            # TODO: more robust is_fbcode test
            import torch.version

            TRACE_LOG_DIR = "/logs"
            open_func = self._builtin_open

            ok = False
            import torch.version as torch_version

            if hasattr(torch_version, "git_version"):
                log.info("LazyFbTraceHandler: disabled because not fbcode")
            elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
                log.info(
                    "LazyFbTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                )
            elif not os.path.exists(TRACE_LOG_DIR):
                log.info(
                    "LazyFbTraceHandler: disabled because %s does not exist",
                    TRACE_LOG_DIR,
                )
            elif not os.access(TRACE_LOG_DIR, os.W_OK):
                log.info(
                    "LazyFbTraceHandler: disabled because %s is not writeable",
                    TRACE_LOG_DIR,
                )
            else:
                ok = True

            if ok:
                ranksuffix = ""
                if dist.is_available() and dist.is_initialized():
                    ranksuffix = f"rank_{dist.get_rank()}_"
                self.stream = tempfile.NamedTemporaryFile(
                    mode="w+",
                    suffix=".log",
                    prefix=f"dedicated_log_torch_trace_{ranksuffix}",
                    dir=TRACE_LOG_DIR,
                    delete=False,
                )
                log.info("LazyFbTraceHandler: logging to %s", self.stream.name)
            else:
                # We go poof, remove and no-op
                trace_log.removeHandler(self)
                return
        if self.stream:
            super().emit(record)


@functools.lru_cache(None)
def warning_once(logger_obj, *args, **kwargs):
    """
    This function is similar to `logger.warning()`, but will emit the warning with the same message only once
    Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
    The assumption here is that all warning messages are unique across the code. If they aren't, then we need to switch
    to another type of cache that includes the caller frame information in the hashing function.
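
    Example (illustrative)::

        >>> # xdoctest: +SKIP
        >>> warning_once(log, "this exact message is only emitted once per process")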
    """
    logger_obj.warning(*args, **kwargs)


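# Defers an expensive string computation until the log record is actually
# formatted.  Illustrative use (expensive_repr is a stand-in for any costly
# function you only want to run if the message is emitted):
#
#   log.debug("%s", LazyString(expensive_repr, graph))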
class LazyString:
    def __init__(self, func, *args, **kwargs):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __str__(self):
        return self.func(*self.args, **self.kwargs)


def trace_structured(
    name: str,
    # NB: metadata expected to be dict so adding more info is forward compatible
    # Tuple[str, int] is a special case for string interning
    metadata_fn: Callable[[], Union[Dict[str, Any], Tuple[str, int]]] = dict,
    *,
    payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
    suppress_context: bool = False,
):
    """
    metadata is an arbitrary JSON compatible struct, but it's expected to not be
    too long (e.g., less than 1MB)

    payload is an arbitrary string, which can be arbitrarily long (but expected to have
    newlines so no lines are too long)
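
    Example (illustrative; the event name and metadata fields are hypothetical)::

        >>> # xdoctest: +SKIP
        >>> trace_structured(
        ...     "my_event",
        ...     metadata_fn=lambda: {"reason": "example"},
        ...     payload_fn=lambda: "some multi-line payload",
        ... )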
    """
    assert name not in ["rank", "frame_id", "frame_compile_id", "attempt"]
    assert callable(
        metadata_fn
    ), f"metadata_fn should be callable, but got {type(metadata_fn)}"
    assert callable(
        payload_fn
    ), f"payload_fn should be callable, but got {type(payload_fn)}"
    # trace_log never propagates and is ALWAYS DEBUG, so also check that there
    # are handlers instead of checking the log level
    if trace_log.handlers:
        record: Dict[str, object] = {}
        record[name] = metadata_fn()
        if not suppress_context:
            # TODO: Actually, the rank probably should just be emitted once at
            # the top, and not repeatedly spammed in all the logs, since it
            # never changes and we assume no interleaving
            if dist.is_available() and dist.is_initialized():
                record["rank"] = dist.get_rank()
            if (
                trace_id := torch._guards.CompileContext.current_trace_id()
            ) is not None:
                record["frame_id"] = trace_id.compile_id.frame_id
                record["frame_compile_id"] = trace_id.compile_id.frame_compile_id
                record["attempt"] = trace_id.attempt
        payload = payload_fn()
        if payload is not None:
            if not isinstance(payload, str):
                if isinstance(payload, list):
                    # special case to look better
                    payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]"
                else:
                    # force newlines so we are unlikely to overflow line limit
                    payload = json.dumps(payload, indent=0)
            h = hashlib.md5()
            h.update(payload.encode("utf-8"))
            record["has_payload"] = h.hexdigest()
        trace_log.debug(
            "", extra={"metadata": record, "payload": payload}, stacklevel=2
        )


import torch._guards
import torch._utils_internal
import torch.distributed as dist