from __future__ import annotations

import contextlib

import dataclasses
import enum
import functools
import logging
import threading
import traceback
import unittest.mock
import weakref
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    TypeVar,
)

import torch
from torch.utils import _pytree as pytree
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary

log = logging.getLogger(__name__)


if TYPE_CHECKING:
    # Import the following modules during type checking to enable code intelligence features,
    # such as auto-completion in tools like pylance, even when these modules are not explicitly
    # imported in user code.

    import sympy


"""
torch._guards is the definitional source of truth for general purpose guard structures.

An important thing to keep in mind here is the preservation of layering. There should be no dynamo notions,
and no guard installation notions here.
"""


class CompileId(NamedTuple):
    frame_id: int
    # This id is per-frame, and counts how many times we've compiled this
    # frame.  This could have been a global id but having this be per-frame
    # gives you a better intuitive sense for how many recompiles have occurred
    # so far.
    frame_compile_id: int
    # TODO: consider also tracking the recompilation count

    def __str__(self):
        return f"{self.frame_id}/{self.frame_compile_id}"


class TraceId(NamedTuple):
    compile_id: CompileId
    # This starts off as 0, and every time we restart analysis it goes
    # up by one
    attempt: int

    def __str__(self):
        if self.attempt == 0:
            return str(self.compile_id)
        else:
            return f"{self.compile_id}_{self.attempt}"


class GuardSource(enum.Enum):
    LOCAL = 0
    GLOBAL = 1
    LOCAL_NN_MODULE = 2
    GLOBAL_NN_MODULE = 3
    CONSTANT = 4
    RANDOM_VALUE = 5
    SHAPE_ENV = 6
    LOCAL_FSDP_MODULE = 7
    GLOBAL_FSDP_MODULE = 8
    BACKWARD_STATE = 9

    def is_fsdp_module(self) -> bool:
        return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)

    def is_nn_module(self) -> bool:
        return (
            self
            in (
                GuardSource.GLOBAL_NN_MODULE,
                GuardSource.LOCAL_NN_MODULE,
            )
            or self.is_fsdp_module()
        )

    def is_local(self):
        return self in (
            GuardSource.LOCAL,
            GuardSource.LOCAL_NN_MODULE,
            GuardSource.LOCAL_FSDP_MODULE,
        )
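
# Illustrative sketch (not part of the original module): the predicates above
# overlap on purpose -- FSDP module sources also count as nn.Module sources.
#
#     GuardSource.LOCAL_FSDP_MODULE.is_fsdp_module()  # True
#     GuardSource.LOCAL_FSDP_MODULE.is_nn_module()    # True (via is_fsdp_module)
#     GuardSource.LOCAL_FSDP_MODULE.is_local()        # True
#     GuardSource.GLOBAL.is_local()                   # False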


"""
Base class for a "GuardBuilder" role.

The GuardBuilderBase role is to represent a scope within which to build a guard. The name is a little
confusing: it is not a builder, but the name is kept to avoid a lot of renames and to preserve the
original reference to torchdynamo's GuardBuilder.

Note: create_fn is invoked with a GuardBuilderBase and a Guard. A GuardBuilder is chosen based
on GuardSource's select function.

There is value in keeping this GuardBuilderBase empty to keep layering clean.
"""


class GuardBuilderBase:
    pass


class ShapeGuard(NamedTuple):
    expr: sympy.Expr
    stack: CapturedTraceback


@dataclasses.dataclass
class Guard:
    # originating_source is the source that called the make_guard method to
    # construct this guard object. The property `name` specifies exactly what
    # the guard is guarding on.  The meaning of the name is dependent on the
    # create_fn; you must look at the use-site inside create_fn to know what
    # name means.
    #
    # That being said, although you might think this is just a "name", name is
    # usually an arbitrary Python expression that will be evaluated with all
    # globals (and locals, if you create a LOCAL guard) to extract the Python
    # object that we want to perform guard tests on.  This evaluation
    # typically happens in GuardBuilder.eval.  In these cases, name is
    # typically produced by originating_source.name() (not to be confused with
    # GuardSource - the property source).
    #
    # Occasionally, name is not a valid Python expression; sometimes
    # it is meaningless.  Example create_fns that are like this include
    # GRAD_MODE and SHAPE_ENV.
    originating_source: Source
    create_fn: Callable[[GuardBuilderBase, Guard], None]

    # Export only. These values are written to at the time of guard check_fn creation.
    guard_types: Optional[List[str]] = None
    code_list: Optional[List[str]] = None
    obj_weakref: Optional[object] = None
    guarded_class_weakref: Optional[type] = None

    stack: Optional[CapturedTraceback] = None
    user_stack: Optional[traceback.StackSummary] = None
    _hash: Optional[int] = None

    def __hash__(self):
        if self._hash is None:
            self._hash = hash((self.name, self.source, id(self.create_fn)))
        return self._hash

    def sort_key(self):
        return (
            self.source.value if self.source else -1,
            len(self.name),
            self.name,
            self.inner_create_fn().__code__.co_firstlineno,
        )

    def __lt__(self, other):
        return self.sort_key() < other.sort_key()

    def inner_create_fn(self):
        if isinstance(self.create_fn, functools.partial):
            return self.create_fn.func
        else:
            return self.create_fn

    @property
    def name(self) -> str:
        return self.originating_source.name()

    @property
    def source(self) -> GuardSource:
        return self.originating_source.guard_source()

    @staticmethod
    def weakref_to_str(obj_weakref):
        """
        This is a workaround for a Python weakref bug.

        `obj_weakref` is an instance returned by `weakref.ref`;
        `str(obj_weakref)` is buggy if the original obj overrides __getattr__, e.g.:

            class MyConfig(dict):
                def __getattr__(self, x):
                    return self[x]

            obj = MyConfig(offset=5)
            obj_weakref = weakref.ref(obj)
            str(obj_weakref)  # raises KeyError: '__name__'
        """
        if isinstance(obj_weakref, weakref.ReferenceType):
            obj = obj_weakref()
            if obj is not None:
                return f"<weakref at {hex(id(obj_weakref))}; to '{obj.__class__.__name__}' at {hex(id(obj))}>"
            else:
                return f"<weakref at {hex(id(obj_weakref))}; dead>"
        else:
            return str(obj_weakref)

    def __repr__(self):
        s = f"""
        {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
        {{
            'guard_types': {self.guard_types},
            'code': {self.code_list},
            'obj_weakref': {self.weakref_to_str(self.obj_weakref)}
            'guarded_class': {self.guarded_class_weakref}
        }}
        """
        return s

    def __str__(self):
        output = f"Name: {repr(self.name)}\n"
        source = self.source.name.lower() if self.source else ""
        output += f"    Source: {source}\n"
        output += f"    Create Function: {self.inner_create_fn().__name__}\n"
        output += f"    Guard Types: {self.guard_types}\n"
        output += f"    Code List: {self.code_list}\n"
        output += f"    Object Weakref: {self.weakref_to_str(self.obj_weakref)}\n"
        output += f"    Guarded Class Weakref: {self.guarded_class_weakref}\n"
        return output

    def create(self, builder: GuardBuilderBase):
        try:
            return self.create_fn(builder, self)
        except Exception:
            log.error("Error while creating guard:\n%s", str(self).rstrip())
            if self.stack:
                log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
            raise

    def is_nn_module(self):
        return self.source.is_nn_module()

    def is_fsdp_module(self):
        return self.source.is_fsdp_module()

    def is_local(self):
        return self.source.is_local()

    def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
        if not self.guard_types:
            self.guard_types = list()

        self.guard_types.append(guard_type)

        assert self.guarded_class_weakref in (
            guarded_class,
            None,
        ), "Guarded class id must be identical, or None"
        self.guarded_class_weakref = guarded_class

        if not self.code_list:
            self.code_list = code_list
        else:
            self.code_list.extend(code_list)

        assert self.obj_weakref in (
            obj_weakref,
            None,
        ), "Guarded object must be identical, or None"
        self.obj_weakref = obj_weakref
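
# Illustrative sketch (not part of the original module): Guard objects are normally
# not constructed directly; a Source (see Source.make_guard below, with concrete
# subclasses in torch/_dynamo/source.py) pairs itself with a create_fn supplied by
# dynamo. LocalSource and GuardBuilder.TYPE_MATCH below are assumed dynamo-side
# names, shown only to sketch the flow:
#
#     from torch._dynamo.source import LocalSource    # assumed import path
#     from torch._dynamo.guards import GuardBuilder   # assumed import path
#
#     src = LocalSource("x")                           # refers to local variable "x"
#     guard = src.make_guard(GuardBuilder.TYPE_MATCH)
#     guard.name    # src.name(): the Python expression the guard is evaluated against
#     guard.source  # GuardSource.LOCAL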


T = TypeVar("T")

"""
Parent structure for guard env expressions.
A GuardEnvExpr can have any subtype.
Note: All subtypes must be handled exhaustively in
torch._dynamo.guards._parse_guard_env_guards to avoid a RuntimeError.
"""


@dataclasses.dataclass
class GuardEnvExpr:
    pass


"""
A class representing a pair of duplicate inputs.
input_source_a and input_source_b are the sources of the inputs we have deduplicated.
"""


@dataclasses.dataclass
class DuplicateInputs(GuardEnvExpr):
    input_source_a: Source
    input_source_b: Source

    def __post_init__(self):
        assert self.input_source_a != self.input_source_b


"""
Checkpointable is an interface for driving state snapshotting, left purposely vague for now.

copy_graphstate() -> T, a somewhat legacy name, is expected to emit a snapshot of any type that
can also be taken in at restore_graphstate(T) calls.

When to snapshot is, at the moment, an implementation detail of upstream callers. Checkpointable
does not provide any guarantees around consistency, idempotency, or safety of calling its APIs, yet.

In the future, it will have a closer coupling to a generic Checkpoint management system.
"""


class Checkpointable(ABC, Generic[T]):
    @abstractmethod
    def copy_graphstate(self) -> T:
        ...

    @abstractmethod
    def restore_graphstate(self, state: T):
        ...
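
# Illustrative sketch (not part of the original module): the checkpoint/restore
# round trip, using the GuardsContext implementation defined later in this file.
#
#     guards_ctx = GuardsContext()
#     snapshot = guards_ctx.copy_graphstate()   # a GuardsCheckpointState
#     ... speculatively accumulate guards ...
#     guards_ctx.restore_graphstate(snapshot)   # roll back to the snapshot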


class GuardsCheckpointState:
    """
    The GuardsCheckpointState - it is the T of Checkpointable[T] for GuardsContext
    """

    dynamo_guards: Set[Guard] = set()

    def __init__(self, dynamo_guards):
        self.dynamo_guards = dynamo_guards

    def diff(self, other):
        """
        Produces a delta against another GuardsCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        Guard type objects.
        """
        r = self.dynamo_guards.difference(other.dynamo_guards)
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class ModuleContextCheckpointState:
    nn_modules: Dict[str, torch.nn.Module] = {}

    def __init__(self, nn_modules):
        self.nn_modules = nn_modules

    def diff(self, other):
        """
        Produces a delta against another ModuleContextCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        module key names.
        """
        r = set(self.nn_modules.keys()).difference(set(other.nn_modules.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
    def __init__(self):
        self.nn_modules: Dict[str, Any] = {}

    def copy_graphstate(self):
        return ModuleContextCheckpointState(dict(self.nn_modules))

    def restore_graphstate(self, state):
        assert isinstance(state, ModuleContextCheckpointState)
        self.nn_modules = state.nn_modules


class GlobalContextCheckpointState:
    global_state: Dict[str, Tuple[Callable, ...]] = {}

    def __init__(self, global_states):
        self.global_state = global_states

    def diff(self, other):
        """
        Produces a delta against another GlobalContextCheckpointState.

        Returns None if no delta is found; otherwise, returns a set() of mismatched
        global key names.
        """
        r = set(self.global_state.keys()).difference(set(other.global_state.keys()))
        if len(r) == 0:
            return None
        return r

    def __eq__(self, other):
        return self.diff(other) is None


class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
    """
    This keeps track of the global torch state during tracing of a function.
    For example, torch.is_grad_enabled.
    """

    _supported_global_states = {
        "grad_enabled",
        "torch_function_enabled",
        "autocast_enabled",
        "autocast_cpu_enabled",
        "autocast_gpu_dtype",
        "autocast_cpu_dtype",
        "autocast_cache_enabled",
    }

    def __init__(self):
        self.global_state: Dict[str, Tuple[Callable, ...]] = {}

    def copy_graphstate(self):
        return GlobalContextCheckpointState(dict(self.global_state))

    def restore_graphstate(self, state):
        assert isinstance(state, GlobalContextCheckpointState)
        self.global_state = state.global_state
        assert (
            len(self.global_state) == len(self._supported_global_states)
            and set(self.global_state.keys()) == self._supported_global_states
        ), "Global state mismatch"
        for func, args in self.global_state.values():
            func(args)


"""
A GuardsContext is a checkpointable representation of all the guards in the current tracing
context. Its lifecycle is bound 1:1 to the tracing context, and it should never be instantiated
directly outside of it. For passing around internal state representations of this object,
prefer to extract them with copy_graphstate to produce a GuardsCheckpointState.
"""


# Like a Set[Guard] but will record the user stack on all guards at the
# time they were installed at their destination
class GuardsSet:
    def __init__(self, inner=None):
        if inner is None:
            inner = set()
        self.inner = inner

    def __iter__(self):
        return iter(self.inner)

    def __len__(self):
        return len(self.inner)

    # Subtraction along with bool is typically used to determine the delta of
    # added guards between checkpoints for higher order ops
    def __sub__(self, other):
        return GuardsSet(self.inner - other.inner)

    def __bool__(self):
        return bool(self.inner)

    def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
        if guard in self.inner:
            return
        if collect_debug_stack:
            if guard.stack is None:
                guard.stack = CapturedTraceback.extract(skip=1 + skip)
            if guard.user_stack is None:
                guard.user_stack = TracingContext.extract_stack()
        self.inner.add(guard)

    def update(self, *others: Set[Guard]):
        for o in others:
            for g in o:
                self.add(g, skip=1)
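
# Illustrative sketch (not part of the original module): the __sub__/__bool__ pair
# above is how callers compute the delta of guards added between two points in
# time (`tracked` below is a hypothetical GuardsSet being populated by tracing):
#
#     before = GuardsSet(set(tracked.inner))   # snapshot the current contents
#     ... tracing adds more guards to tracked ...
#     added = tracked - before                 # GuardsSet containing only the new guards
#     if added:                                # __bool__: were any guards added?
#         ...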


class GuardsContext(Checkpointable[GuardsCheckpointState]):
    def __init__(self):
        self.dynamo_guards: GuardsSet = GuardsSet()
        self.aotautograd_guards: List[GuardEnvExpr] = []

    def copy_graphstate(self):
        return GuardsCheckpointState(set(self.dynamo_guards.inner))

    def restore_graphstate(self, state):
        # NB: "steals" the passed in state
        assert isinstance(state, GuardsCheckpointState)
        self.dynamo_guards = GuardsSet(state.dynamo_guards)


_TLS = threading.local()

"""
TracingContext is the source of truth for all currently accumulated information
needed to trace. Its lifecycle is kept 1:1 when using TorchDynamo, but other systems
are open to managing their own TracingContext with that in mind.

The purpose of TracingContext is not to be a dumping ground, or god object, but rather to avoid
having to plumb complex subsystems across multiple verticals.

Ex: A common example is guard accumulation between dynamo, shape_env, aot_autograd, and inductor.
Accessing the current tracing context via
TracingContext.get() allows users to accumulate their own guards for processing, without needing to know how
to plumb objects back up to where frame interpretation happened.

Note that you can end up with multiple TracingContext for a single compilation
of a frame, as we reset the TracingContext whenever we restart analysis.
CompileContext is a more overarching context that encompasses multiple restarts.
"""


class CompileContext:
    @staticmethod
    def get() -> CompileContext:
        assert _TLS.compile_context is not None
        return _TLS.compile_context

    @staticmethod
    def try_get() -> Optional[CompileContext]:
        return getattr(_TLS, "compile_context", None)

    def __init__(self, compile_id):
        assert compile_id is None or isinstance(compile_id, CompileId)
        self.compile_id: Optional[CompileId] = compile_id
        self.attempt = 0

    @staticmethod
    def current_compile_id():
        self = CompileContext.try_get()
        if self is None:
            return None
        return self.compile_id

    @staticmethod
    def current_trace_id():
        self = CompileContext.try_get()
        if self is None:
            return None
        if self.compile_id is None:
            return None
        return TraceId(self.compile_id, self.attempt)


class TracingContext:
    """
    Provides the currently installed TracingContext, or None.

    Note that it is a staticmethod, and invocations outside of `with tracing()` (see below) are valid but
    will return None.
    """

    @staticmethod
    def try_get() -> Optional[TracingContext]:
        return getattr(_TLS, "tracing_context", None)

    @staticmethod
    def get() -> TracingContext:
        if ctx := TracingContext.try_get():
            return ctx
        raise RuntimeError(
            "TracingContext.get() must be called within an ongoing trace."
        )

    def __init__(self, fake_mode):
        self.guards_context = GuardsContext()
        self.module_context = ModuleContext()
        self.global_context = GlobalContext()
        self.fake_mode = fake_mode
        self.frame_summary_stack = []
        # This is morally part of frame_summary_stack, but it is kept separate
        # for clarity.  As we process a frame, this variable gets updated
        # to keep track of what line we are on in the function.  When we make a
        # function call, this gets cleared and the frame location is pushed
        # to frame_summary_stack (prepping this variable for the inner frame's
        # progress)
        self.loc_in_frame = None
        # this is only set after aot_autograd
        self.fw_metadata = None
        self.params_flat = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,
        # or None if no stride is known.  This is always the HINT, it
        # is never a SymInt (it would be better if it was a SymInt, but
        # I can't conveniently get this from Inductor atm.  Also, be
        # careful not to accidentally induce guards on the SymInt if
        # you ever do change this in aot_autograd.py; you should check
        # on permutations preferentially.)
        self.output_strides: Optional[List[Optional[List[int]]]] = None
        # When this is True, whenever we encounter an int in Dynamo tracing,
        # we will (1) force unspec it and (2) force it as a size-like unbacked
        # integer.  This is currently used when processing certain lists of
        # ints that are known to be size-like and may have 0/1 entries that we
        # must not specialize on.
        self.force_unspec_int_unbacked_size_like = False
        # See note [Tensor Fakification and Symbol Caching]
        self.tensor_to_context = WeakTensorKeyDictionary()

        # If this is true, AOT Autograd will return output fake tensors with appropriate
        # meta on the first invocation
        # see note: [Returning Fake Tensors on First AOT Autograd Call]
        self.fakify_first_call = False

    def clear(self):
        # Look at the note in output_graph.py in function `save_global_state`
        # for the context on clearing global context.
        self.global_context.global_state = {}

    @staticmethod
    @contextmanager
    def patch(**kwargs):
        prior = {}
        ctx = TracingContext.get()

        for key in kwargs.keys():
            # KeyError on invalid entry
            prior[key] = getattr(ctx, key)
        for key, val in kwargs.items():
            setattr(ctx, key, val)
        try:
            yield
        finally:
            for key, val in prior.items():
                setattr(ctx, key, val)

    @staticmethod
    def extract_stack():
        self = TracingContext.try_get()
        if self is None:
            return traceback.StackSummary()
        stack = self.frame_summary_stack
        if self.loc_in_frame is not None:
            stack = stack + [self.loc_in_frame]
        return traceback.StackSummary.from_list(stack)

    # Call this when you want to call into some code that isn't necessarily
    # associated with the current frame state
    @staticmethod
    @contextlib.contextmanager
    def clear_frame():
        tc = TracingContext.get()
        with unittest.mock.patch.object(
            tc, "frame_summary_stack", []
        ), unittest.mock.patch.object(tc, "loc_in_frame", None):
            try:
                yield
            except Exception as e:
                # Prevent real_stack from getting attached
                #
                # The invariant is that if an Exception has real_stack, we've
                # appropriately attached a user stack and we no longer need to
                # attach anything. Because we cannot conveniently interpose
                # when an exception is thrown, we instead interpose everywhere
                # the user stack is set (using the context
                # manager). However, our compiler stack does "tail calls"
                # (when it calls into user compiler), at which point the
                # parent exception frames would attach an
                # incorrect frame.
                #
                # However, if, somehow, someone raised an exception within this
                # scope that had a stack (for example, because they are
                # restoring the user stack state appropriately as they process
                # node by node), we should respect it. Thus, we cannot
                # unconditionally set None.
                if not hasattr(e, "real_stack"):
                    e.real_stack = None  # type: ignore[attr-defined]
                raise

    @staticmethod
    @contextlib.contextmanager
    def current_frame(frame_summary):
        # frame_summary can be None to solely take advantage of real_stack
        # attachment to thrown exceptions
        tc = TracingContext.get()
        if frame_summary is not None:
            tc.frame_summary_stack.append(frame_summary)
        old = tc.loc_in_frame
        tc.loc_in_frame = None
        try:
            yield
        except Exception as e:
            if not hasattr(e, "real_stack"):
                e.real_stack = tc.extract_stack()  # type: ignore[attr-defined]
            raise
        finally:
            if frame_summary is not None:
                tc.frame_summary_stack.pop()
            tc.loc_in_frame = old

    @staticmethod
    @contextlib.contextmanager
    def report_output_strides():
        tc = TracingContext.try_get()
        if tc is None:
            yield None
            return
        old_output_strides = tc.output_strides
        tc.output_strides = []
        try:
            yield tc.output_strides
        finally:
            tc.output_strides = old_output_strides

    @staticmethod
    def set_current_loc(filename, lineno, frame_name):
        TracingContext.get().loc_in_frame = traceback.FrameSummary(
            filename, lineno, frame_name
        )
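
# Illustrative sketch (not part of the original module): TracingContext.patch
# temporarily overrides attributes on the currently installed context and restores
# them on exit. It must run under an active `with tracing(...)` scope (see below):
#
#     with tracing(TracingContext(fake_mode=None)):
#         with TracingContext.patch(force_unspec_int_unbacked_size_like=True):
#             assert TracingContext.get().force_unspec_int_unbacked_size_like
#         assert not TracingContext.get().force_unspec_int_unbacked_size_like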


@contextmanager
def compile_context(context: CompileContext):
    old_context = getattr(_TLS, "compile_context", None)
    _TLS.compile_context = context
    try:
        yield context
    finally:
        _TLS.compile_context = old_context
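
# Illustrative sketch (not part of the original module): installing a CompileContext
# makes the current compile/trace ids queryable from anywhere on this thread.
#
#     with compile_context(CompileContext(CompileId(frame_id=0, frame_compile_id=1))):
#         CompileContext.current_compile_id()      # CompileId(frame_id=0, frame_compile_id=1)
#         str(CompileContext.current_trace_id())   # "0/1"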


@contextmanager
def tracing(context: Optional[TracingContext]):
    """
    This function installs the passed-in tracing context as a dynamically scoped
    global variable.

    Calls to TracingContext.try_get() while not under a `with tracing()` context
    will return None.
    """
    old_context = getattr(_TLS, "tracing_context", None)
    _TLS.tracing_context = context
    try:
        yield context
    except Exception as e:
        if not hasattr(e, "real_stack") and context is not None:
            e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
        raise
    finally:
        if (
            context is not None
            and context.fake_mode is not None
            and context.fake_mode.shape_env is not None
        ):
            context.fake_mode.shape_env.cleanup()
        _TLS.tracing_context = old_context
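
# Illustrative sketch (not part of the original module): installing a TracingContext
# so downstream code can accumulate guards without explicit plumbing. fake_mode is
# passed as None purely for brevity; real callers install a FakeTensorMode.
#
#     ctx = TracingContext(fake_mode=None)
#     with tracing(ctx):
#         TracingContext.get() is ctx                         # True
#         TracingContext.get().guards_context.dynamo_guards   # GuardsSet to add to
#     TracingContext.try_get()                                # None outside the scope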


# Subclasses can be found in torch/_dynamo/source.py
# TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True)
class Source:
    def is_dict_key(self):
        return False

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self) -> GuardSource:
        raise NotImplementedError()

    def name(self) -> str:
        raise NotImplementedError()

    def make_guard(self, fn) -> Guard:
        if self.guard_source() is GuardSource.CONSTANT:
            raise NotImplementedError()
        return Guard(self, fn)

    def is_nn_module(self) -> bool:
        return self.guard_source().is_nn_module()


# Subclasses can be found in torch/_dynamo/source.py
@dataclasses.dataclass(frozen=True)
class ChainedSource(Source):
    base: Source

    def is_dict_key(self):
        # Recurse until you either hit a ConstDictKey or a Source
        return self.base.is_dict_key()


def detect_fake_mode(inputs: Any = None):
    """
    Attempts to "detect" what the current fake mode is.  If there is one ambiently
    available from TracingContext, we preferentially use that.  Otherwise, we
    heuristically detect the fake mode via the following sources, in order of
    priority:

        - Currently active fake mode on stack
        - Fake mode associated with passed in tensors (inputs does not
          have to be flattened)
    """
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

    fake_modes = []

    if context := TracingContext.try_get():
        fake_mode = context.fake_mode
        if fake_mode is not None:
            fake_modes.append((fake_mode, "tracing context", 0))

    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
        if isinstance(m, FakeTensorMode):
            fake_modes.append((m, "active fake mode", i))

    flat_inputs = pytree.tree_leaves(inputs)
    for i, flat_input in enumerate(flat_inputs):
        if isinstance(flat_input, FakeTensor):
            fake_modes.append((flat_input.fake_mode, "fake tensor input", i))

    if fake_modes:
        fake_mode, desc1, i1 = fake_modes[0]
        for m, desc2, i2 in fake_modes[1:]:
            assert fake_mode is m, (
                f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
                f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
                f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
            )
        return fake_mode
    else:
        return None
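
# Illustrative sketch (not part of the original module): recovering the fake mode
# from fake tensor inputs. Tensors created under a FakeTensorMode come out as
# FakeTensors carrying that mode (assuming no ambient mode or TracingContext is
# installed when detect_fake_mode is called):
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#
#     with FakeTensorMode() as mode:
#         x = torch.empty(2, 3)          # a FakeTensor whose .fake_mode is `mode`
#     detect_fake_mode([x]) is mode      # True: recovered from the input leaves
#     detect_fake_mode() is None         # True: nothing to detect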


def active_fake_mode():
    """
    Inspects the dispatch mode stack for an active fake mode and returns it.
    Returns None if no fake mode is active.
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.utils._python_dispatch import _get_current_dispatch_mode_stack

    for _, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
        if isinstance(m, FakeTensorMode):
            return m

    return None