from __future__ import annotations

import contextlib, dataclasses, enum, functools, logging, threading, traceback
import unittest.mock, weakref
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generic, List, NamedTuple, Optional, Set, Tuple, TypeVar

import torch
from torch.utils import _pytree as pytree
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary

log = logging.getLogger(__name__)

T = TypeVar("T")
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.

torch._guards is the definitional source of truth for general-purpose guard structures.

An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
and no guard installation notions here.

class CompileId(NamedTuple):
# This id is per-frame, and counts how many times we've compiled this
# frame. This could have been a global id but having this be per-frame
# gives you a better intuitive sense for how many recompiles have occurred.
# TODO: consider also tracking the recompilation count
return f"{self.frame_id}/{self.frame_compile_id}"

class TraceId(NamedTuple):
# This starts off as 0, and every time we restart analysis it goes
# up by one.
return str(self.compile_id)
return f"{self.compile_id}_{self.attempt}"

class GuardSource(enum.Enum):
GLOBAL_FSDP_MODULE = 8

def is_fsdp_module(self) -> bool:
return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

def is_nn_module(self) -> bool:
GuardSource.GLOBAL_NN_MODULE,
GuardSource.LOCAL_NN_MODULE,
or self.is_fsdp_module()
GuardSource.LOCAL_NN_MODULE,
GuardSource.LOCAL_FSDP_MODULE,

Base class for a "GuardBuilder" role.

The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing, as it's not a builder; the name is kept to avoid a lot of renames and to preserve the original
reference to torchdynamo's GuardBuilder.

Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
on GuardSource's select function.

There is value in keeping this GuardBuilderBase empty to keep layering clean.
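
# A minimal sketch (hypothetical function, not part of this module) of a create_fn
# matching the Callable[[GuardBuilderBase, Guard], None] shape used below; the real
# builders live alongside torchdynamo's GuardBuilder:
#
#     def type_match(builder: GuardBuilderBase, guard: Guard) -> None:
#         builder.TYPE_MATCH(guard)  # hypothetical builder method emitting a type check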

class GuardBuilderBase:

class ShapeGuard(NamedTuple):
stack: CapturedTraceback

@dataclasses.dataclass
class Guard:
# originating_source is the source that called the make_guard method to
# construct this guard object. The property name specifies exactly what the
# guard is guarding on. The meaning of the name is dependent on the
# create_fn; you must look at the use-site inside create_fn to know what
# name means.
#
# That being said, although you might think this is just a "name", name is
# usually an arbitrary Python expression that will be evaluated with all
# globals (and locals, if you create a LOCAL guard) to extract the Python
# object that we want to perform guard tests on. This evaluation
# typically happens in GuardBuilder.eval. In these cases, name is
# typically produced by originating_source.name() (not to be confused with
# GuardSource - the property source).
#
# Occasionally, name is not a valid Python expression; sometimes
# it is meaningless. Example create_fns that are like this include
# GRAD_MODE and SHAPE_ENV.
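#
# For example (hypothetical guard): for a local source, name might be the
# expression "L['x']", which GuardBuilder.eval would evaluate against the
# frame's locals to obtain the object the guard's check runs on.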
originating_source: Source
create_fn: Callable[[GuardBuilderBase, Guard], None]

# Export only. These values are written to at the time of guard check_fn creation.
guard_types: Optional[List[str]] = None
code_list: Optional[List[str]] = None
obj_weakref: Optional[object] = None
guarded_class_weakref: Optional[type] = None

stack: Optional[CapturedTraceback] = None
user_stack: Optional[traceback.StackSummary] = None
_hash: Optional[int] = None

if self._hash is None:
self._hash = hash((self.name, self.source, id(self.create_fn)))
self.source.value if self.source else -1,
self.inner_create_fn().__code__.co_firstlineno,

def __lt__(self, other):
return self.sort_key() < other.sort_key()

def inner_create_fn(self):
if isinstance(self.create_fn, functools.partial):
return self.create_fn.func
return self.create_fn

def name(self) -> str:
return self.originating_source.name()

def source(self) -> GuardSource:
return self.originating_source.guard_source()

def weakref_to_str(obj_weakref):
This is a workaround for a Python weakref bug.

`obj_weakref` is an instance returned by `weakref.ref`;
`str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g.:

class MyConfig(dict):
def __getattr__(self, x):

obj = MyConfig(offset=5)
obj_weakref = weakref.ref(obj)
str(obj_weakref)  # raises KeyError: '__name__'

if isinstance(obj_weakref, weakref.ReferenceType):
obj = obj_weakref()
return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
return f"<weakref at {hex(id(obj_weakref))}; dead>"
return str(obj_weakref)

{self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
'guard_types': {self.guard_types},
'code': {self.code_list},
'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
'guarded_class': {self.guarded_class_weakref}

output = f"Name: {repr(self.name)}\n"
source = self.source.name.lower() if self.source else ""
output += f" Source: {source}\n"
output += f" Create Function: {self.inner_create_fn().__name__}\n"
output += f" Guard Types: {self.guard_types}\n"
output += f" Code List: {self.code_list}\n"
output += f" Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"

def create(self, builder: GuardBuilderBase):
return self.create_fn(builder, self)
log.error("Error while creating guard:\n%s", str(self).rstrip())
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())

def is_nn_module(self):
return self.source.is_nn_module()

def is_fsdp_module(self):
return self.source.is_fsdp_module()

return self.source.is_local()

def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
if not self.guard_types:
self.guard_types = list()
self.guard_types.append(guard_type)
assert self.guarded_class_weakref in (
), "Guarded class id must be identical, or None"
self.guarded_class_weakref = guarded_class
if not self.code_list:
self.code_list = code_list
self.code_list.extend(code_list)
assert self.obj_weakref in (
), "Guarded object must be identical, or None"
self.obj_weakref = obj_weakref

Parent structure for guard env expressions.

A GuardEnvExpr can have any subtype.

Note: All subtypes must be handled exhaustively in
torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.

@dataclasses.dataclass
class GuardEnvExpr:

A class representing a pair of duplicate inputs.

input_source_a and input_source_b are the sources of the two inputs we have deduplicated.

@dataclasses.dataclass
class DuplicateInputs(GuardEnvExpr):
input_source_a: Source
input_source_b: Source

def __post_init__(self):
assert self.input_source_a != self.input_source_b

Checkpointable is an interface for driving state snapshotting, left purposely vague for now.

copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
can also be taken in at restore_graphstate(T) calls.

When to snapshot is, at the moment, an implementation detail of upstream callers. Checkpointable
does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.

In the future, it will have a closer coupling to a generic Checkpoint management system.

class Checkpointable(ABC, Generic[T]):
def copy_graphstate(self) -> T:
def restore_graphstate(self, state: T):
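
# A minimal usage sketch (hypothetical variable names): callers snapshot state
# before speculative work and roll back if it fails.
#
#     snapshot = ctx.copy_graphstate()   # ctx is any Checkpointable
#     try:
#         ...  # speculative tracing that may mutate ctx
#     except Exception:
#         ctx.restore_graphstate(snapshot)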

class GuardsCheckpointState:
The GuardsCheckpointState - it is the T of Checkpointable[T] for GuardsContext

dynamo_guards: Set[Guard] = set()

def __init__(self, dynamo_guards):
self.dynamo_guards = dynamo_guards

def diff(self, other):
Produces a delta against another GuardsCheckpointState.

Returns None if no delta is found; otherwise, returns a set() of mismatched guards.
r = self.dynamo_guards.difference(other.dynamo_guards)

def __eq__(self, other):
return self.diff(other) is None

class ModuleContextCheckpointState:
nn_modules: Dict[str, torch.nn.Module] = {}

def __init__(self, nn_modules):
self.nn_modules = nn_modules

def diff(self, other):
Produces a delta against another ModuleContextCheckpointState.

Returns None if no delta is found; otherwise, returns a set() of mismatched module names.
r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))

def __eq__(self, other):
return self.diff(other) is None

class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
self.nn_modules: Dict[str, Any] = {}

def copy_graphstate(self):
return ModuleContextCheckpointState(dict(self.nn_modules))

def restore_graphstate(self, state):
assert isinstance(state, ModuleContextCheckpointState)
self.nn_modules = state.nn_modules

class GlobalContextCheckpointState:
global_state: Dict[str, Tuple[Callable, ...]] = {}

def __init__(self, global_states):
self.global_state = global_states

def diff(self, other):
Produces a delta against another GlobalContextCheckpointState.

Returns None if no delta is found; otherwise, returns a set() of mismatched global state keys.
r = set(self.global_state.keys()).difference(set(other.global_state.keys()))

def __eq__(self, other):
return self.diff(other) is None

class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
This keeps track of the global torch state during tracing of a function.
For example, torch.is_grad_enabled.

_supported_global_states = {
"torch_function_enabled",
"autocast_cpu_enabled",
"autocast_gpu_dtype",
"autocast_cpu_dtype",
"autocast_cache_enabled",

self.global_state: Dict[str, Tuple[Callable, ...]] = {}

def copy_graphstate(self):
return GlobalContextCheckpointState(dict(self.global_state))

def restore_graphstate(self, state):
assert isinstance(state, GlobalContextCheckpointState)
self.global_state = state.global_state
len(self.global_state) == len(self._supported_global_states)
and set(self.global_state.keys()) == self._supported_global_states
), "Global state mismatch"
for func, args in self.global_state.values():
func(*args)
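
# For illustration (hypothetical entry): global_state might hold something like
# {"grad_enabled": (torch.set_grad_enabled, (True,))}, which the loop above
# re-applies by calling func(*args).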

A GuardsContext is a checkpointable representation of all the guards in the current tracing
context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
directly outside of it. For passing around internal state representations of this object,
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.

# Like a Set[Guard] but will record the user stack on all guards at the
# time they were installed at their destination
def __init__(self, inner=None):

return iter(self.inner)
return len(self.inner)

# Subtraction along with bool is typically used to determine the delta of
# added guards between checkpoints for higher order ops; see the sketch below.
def __sub__(self, other):
return GuardsSet(self.inner - other.inner)

return bool(self.inner)
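
# A sketch of the delta idiom mentioned above (hypothetical caller code):
#
#     before = GuardsSet(set(guards.inner))
#     ...  # trace a higher-order op body, which may install more guards
#     added = guards - before
#     if added:  # __bool__: True only if new guards were installed
#         ...  # propagate the newly added guards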

def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
if guard in self.inner:
if collect_debug_stack:
if guard.stack is None:
guard.stack = CapturedTraceback.extract(skip=1 + skip)
if guard.user_stack is None:
guard.user_stack = TracingContext.extract_stack()
self.inner.add(guard)

def update(self, *others: Set[Guard]):

class GuardsContext(Checkpointable[GuardsCheckpointState]):
self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: List[GuardEnvExpr] = []

def copy_graphstate(self):
return GuardsCheckpointState(set(self.dynamo_guards.inner))

def restore_graphstate(self, state):
# NB: "steals" the passed in state
assert isinstance(state, GuardsCheckpointState)
self.dynamo_guards = GuardsSet(state.dynamo_guards)

_TLS = threading.local()

TracingContext is the source of truth for all currently accumulated information
needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
are open to managing their own TracingContext with that in mind.

The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
having to plumb complex subsystems across multiple verticals.

A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor:
accessing the current tracing context via TracingContext.get() allows users to accumulate their
own guards for processing, without needing to know how to plumb objects back up to where frame
interpretation happened.

Note that you can end up with multiple TracingContext for a single compilation
of a frame, as we reset the TracingContext whenever we restart analysis.
CompileContext is a more overarching context that encompasses multiple restarts.
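
# A sketch of the accumulation pattern described above (hypothetical source and
# create_fn names); any subsystem running under an active trace can do:
#
#     ctx = TracingContext.get()
#     ctx.guards_context.dynamo_guards.add(some_source.make_guard(some_create_fn))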

def get() -> CompileContext:
assert _TLS.compile_context is not None
return _TLS.compile_context

def try_get() -> Optional[CompileContext]:
return getattr(_TLS, "compile_context", None)

def __init__(self, compile_id):
assert compile_id is None or isinstance(compile_id, CompileId)
self.compile_id: Optional[CompileId] = compile_id

def current_compile_id():
self = CompileContext.try_get()
return self.compile_id

def current_trace_id():
self = CompileContext.try_get()
if self.compile_id is None:
return TraceId(self.compile_id, self.attempt)

Provides the currently installed TracingContext, or None.

Note that it is a staticmethod, and invocations outside of `with tracing()` (see below) are valid but
will return None.

def try_get() -> Optional[TracingContext]:
return getattr(_TLS, "tracing_context", None)

def get() -> TracingContext:
if ctx := TracingContext.try_get():
"TracingContext.get() must be called within an ongoing trace."

def __init__(self, fake_mode):
self.guards_context = GuardsContext()
self.module_context = ModuleContext()
self.global_context = GlobalContext()
self.fake_mode = fake_mode
self.frame_summary_stack = []
# This is morally part of frame_summary_stack, but it is kept separate
# for clarity. As we process a frame, this variable gets updated
# to keep track of which line of the function we are on. When we make a
# function call, this gets cleared and the frame location is pushed
# to frame_summary_stack (prepping this variable for the inner frame's
# progress).
self.loc_in_frame = None
# this is only set after aot_autograd
self.fw_metadata = None
self.params_flat = None
# this is for the extended return calling convention from backend
# compiler to aot_autograd
# Per output, what the compiler specified stride of the output is,
# or None if no stride is known. This is always the HINT, it
# is never a SymInt (it would be better if it were a SymInt, but
# I can't conveniently get this from Inductor atm. Also, be
# careful not to accidentally induce guards on the SymInt if
# you ever do change this in aot_autograd.py; you should check
# on permutations preferentially.)
self.output_strides: Optional[List[Optional[List[int]]]] = None
# When this is True, whenever we encounter an int in Dynamo tracing,
# we will (1) force unspec it and (2) force it as a size-like unbacked
# integer. This is currently used when processing certain lists of
# ints that are known to be size-like and may have 0/1 entries that we
# must not specialize on.
self.force_unspec_int_unbacked_size_like = False
# See note [Tensor Fakification and Symbol Caching]
self.tensor_to_context = WeakTensorKeyDictionary()

# If this is True, AOT Autograd will return output Fake Tensors with appropriate
# meta on the first invocation
# see note: [Returning Fake Tensors on First AOT Autograd Call]
self.fakify_first_call = False

# Look at the note in output_graph.py in function `save_global_state`
# for the context on clearing global context.
self.global_context.global_state = {}

prior = {}
ctx = TracingContext.get()

for key in kwargs.keys():
# KeyError on invalid entry
prior[key] = getattr(ctx, key)
for key, val in kwargs.items():
setattr(ctx, key, val)

for key, val in prior.items():
setattr(ctx, key, val)

self = TracingContext.try_get()
return traceback.StackSummary()
stack = self.frame_summary_stack
if self.loc_in_frame is not None:
stack = stack + [self.loc_in_frame]
return traceback.StackSummary.from_list(stack)

# Call this when you want to call into some code that isn't necessarily
# associated with the current frame state
@contextlib.contextmanager
tc = TracingContext.get()
with unittest.mock.patch.object(
tc, "frame_summary_stack", []
), unittest.mock.patch.object(tc, "loc_in_frame", None):
except Exception as e:
# Prevent real_stack from getting attached
#
# The invariant is that if an Exception has real_stack, we've
# appropriately attached a user stack and we no longer need to
# attach anything. Because we cannot conveniently interpose
# when an exception is thrown, we instead interpose everywhere
# we set the user stack (using the context
# manager). However, our compiler stack does "tail calls"
# (when it calls into the user compiler), at which point the
# parent exception frames would incorrectly attach an
# incorrect frame.
#
# However, if, somehow, someone raised an exception with this
# scope that had a stack (for example, because they are
# restoring the user stack state appropriately as they process
# node by node), we should respect it. Thus, we cannot
# unconditionally set None.
if not hasattr(e, "real_stack"):
e.real_stack = None # type: ignore[attr-defined]

@contextlib.contextmanager
def current_frame(frame_summary):
# frame_summary can be None to solely take advantage of real_stack
# attachment to thrown exceptions
tc = TracingContext.get()
if frame_summary is not None:
tc.frame_summary_stack.append(frame_summary)
old = tc.loc_in_frame
tc.loc_in_frame = None
except Exception as e:
if not hasattr(e, "real_stack"):
e.real_stack = tc.extract_stack() # type: ignore[attr-defined]
if frame_summary is not None:
tc.frame_summary_stack.pop()
tc.loc_in_frame = old

@contextlib.contextmanager
def report_output_strides():
tc = TracingContext.try_get()
old_output_strides = tc.output_strides
tc.output_strides = []
yield tc.output_strides
tc.output_strides = old_output_strides
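
# A usage sketch for report_output_strides (hypothetical call names): the backend
# compiler can record, per output, the strides it chose, so that aot_autograd can
# read them back after compilation.
#
#     with TracingContext.report_output_strides() as output_strides:
#         compiled_fn = backend_compile(gm, example_inputs)  # hypothetical backend call
#     # output_strides now lists per-output strides (or is None if no context was active)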

def set_current_loc(filename, lineno, frame_name):
TracingContext.get().loc_in_frame = traceback.FrameSummary(
filename, lineno, frame_name

def compile_context(context: CompileContext):
old_context = getattr(_TLS, "compile_context", None)
_TLS.compile_context = context
_TLS.compile_context = old_context

def tracing(context: Optional[TracingContext]):
This function installs the passed in tracing context as a dynamically scoped
global variable.

Calls to TracingContext.get() while not under a `with tracing()` context
will raise an error.

old_context = getattr(_TLS, "tracing_context", None)
_TLS.tracing_context = context
except Exception as e:
if not hasattr(e, "real_stack") and context is not None:
e.real_stack = context.extract_stack() # type: ignore[attr-defined]
and context.fake_mode is not None
and context.fake_mode.shape_env is not None
context.fake_mode.shape_env.cleanup()
_TLS.tracing_context = old_context
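
# A sketch of how a frame-compilation driver might install these scopes
# (hypothetical compile_id / fake_mode values):
#
#     with compile_context(CompileContext(compile_id)), tracing(TracingContext(fake_mode)):
#         ...  # trace the frame; TracingContext.get() works inside this block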

# Subclasses can be found in torch/_dynamo/source.py
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
def is_dict_key(self):

def reconstruct(self, codegen):
raise NotImplementedError()

def guard_source(self) -> GuardSource:
raise NotImplementedError()

def name(self) -> str:
raise NotImplementedError()

def make_guard(self, fn) -> Guard:
if self.guard_source() is GuardSource.CONSTANT:
raise NotImplementedError()
return Guard(self, fn)
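
# For example (hypothetical create_fn): a dynamo source from torch/_dynamo/source.py
# is typically turned into a guard with something like
#
#     guard = source.make_guard(GuardBuilder.TYPE_MATCH)  # GuardBuilder is dynamo's builder
#
# which packages the source and the create_fn into a Guard for later installation.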

def is_nn_module(self) -> bool:
return self.guard_source().is_nn_module()

# Subclasses can be found in torch/_dynamo/source.py
@dataclasses.dataclass(frozen=True)
class ChainedSource(Source):
base: Source

def is_dict_key(self):
# Recurse until you either hit a ConstDictKey or a Source
return self.base.is_dict_key()

def detect_fake_mode(inputs: Any = None):
Attempts to "detect" what the current fake mode is. If there is one ambiently
available from TracingContext, we preferentially use that. Otherwise, we
heuristically detect the fake mode via the following sources, in order of
priority:
- Currently active fake mode on stack
- Fake mode associated with passed in tensors (inputs does not
have to be flattened)
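
# A usage sketch (hypothetical fake tensors): callers holding only example inputs
# can recover the relevant FakeTensorMode without plumbing it explicitly:
#
#     fake_mode = detect_fake_mode((fake_x, {"y": fake_y}))  # nested structures are fine
#     if fake_mode is None:
#         ...  # no ambient or input-derived fake mode was found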

from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

fake_modes = []

if context := TracingContext.try_get():
fake_mode = context.fake_mode
if fake_mode is not None:
fake_modes.append((fake_mode, "tracing context", 0))

from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
if isinstance(m, FakeTensorMode):
fake_modes.append((m, "active fake mode", i))

flat_inputs = pytree.tree_leaves(inputs)
for i, flat_input in enumerate(flat_inputs):
if isinstance(flat_input, FakeTensor):
fake_modes.append((flat_input.fake_mode, "fake tensor input", i))

fake_mode, desc1, i1 = fake_modes[0]
for m, desc2, i2 in fake_modes[1:]:
assert fake_mode is m, (
f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"

def active_fake_mode():
Inspects the dispatch mode stack for an active fake mode and returns it.
Returns None if no fake mode is active.

from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
if isinstance(m, FakeTensorMode):