# pytorch — torch/_jit_internal.py (1510 lines, 51.9 KB)
1
"""
2
The weak_script annotation needs to be here instead of inside torch/jit/ so it
3
can be used in other places in torch/ (namely torch.nn) without running into
4
circular dependency problems
5
"""
6

7
import ast
8
import builtins
9
import collections
10
import contextlib
11
import enum
12
import inspect
13
import io
14
import pickle
15
import sys
16
import threading
17
import types
18
import typing
19
import warnings
20
import weakref
21
from textwrap import dedent
22
from typing import (  # noqa: F401
23
    Any,
24
    Callable,
25
    Dict,
26
    Final,
27
    ForwardRef,
28
    Generic,
29
    get_args,  # new in 3.8
30
    get_origin,  # new in 3.8
31
    List,
32
    Optional,
33
    Tuple,
34
    Type,
35
    TypeVar,
36
    Union,
37
)
38

39
import torch
40

41
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
42
# Explicitly ask to import `torch.distributed.__init__` first.
43
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
44
import torch.distributed.rpc
45
import torch.package._mangling as package_mangling
46
from torch._awaits import _Await
47
from torch._C import _Await as CAwait, Future as CFuture
48
from torch._sources import fake_range, get_source_lines_and_file, parse_def
49
from torch.futures import Future
50

51
# Version-detection flags used to gate version-specific AST/typing handling below.
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)

# On 3.10+, `int | str` annotations produce types.UnionType instances; earlier
# versions have no such type, so an empty tuple is substituted.
BuiltinUnionType: Union[Type, Tuple[Type, ...]]
if sys.version_info >= (3, 10):
    # NOTE: IS_PY310_PLUS doesn't work with mypy.
    # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
    BuiltinUnionType = types.UnionType
else:
    BuiltinUnionType = ()  # trick: this makes isinstance short circuit.

# Concrete lock type, falling back to the dummy implementation when real
# threading is unavailable.
LockType: Type
try:
    import _thread

    LockType = _thread.LockType
except ImportError:
    import _dummy_thread  # type: ignore[import-not-found]

    LockType = _dummy_thread.LockType

# Wrapper functions that can call either of 2 functions depending on a boolean
# argument
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
    weakref.WeakKeyDictionary()
)  # noqa: T484


# Prefix for synthetic filenames given to dynamically generated dataclass code.
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
80

81

82
class SourceLoader:
    """Process-wide cache mapping callables to their recorded source text."""

    def __init__(self):
        # Maps a function object to the source string registered for it.
        self.content = {}

    def cache(self, fn, source):
        """Record `source` as the source text for `fn`."""
        self.content[fn] = source

    def get_source(self, fn):
        """Return the cached source for `fn`, or None if none was recorded."""
        if fn in self.content:
            return self.content[fn]
        return None


# Singleton consulted before falling back to inspect.getsource.
loader = SourceLoader()
94

95

96
def createResolutionCallbackFromEnv(lookup_base):
    """
    Creates a resolution callback that will look up qualified names in an
    environment, starting with `lookup_base` for the base of any qualified
    names, then proceeding down the lookup chain with the resolved object.

    You should not use this directly, it should only be used from the other
    createResolutionCallbackFrom* functions.
    """

    def lookupInModule(qualified_name, module):
        # Resolve a dotted name (e.g. "a.b.c") one attribute at a time.
        if "." in qualified_name:
            base, remaining_pieces = qualified_name.split(".", maxsplit=1)
            module_value = getattr(module, base)
            return lookupInModule(remaining_pieces, module_value)
        else:
            return getattr(module, qualified_name)

    def parseNestedExpr(expr, module) -> Tuple[Any, int]:
        # Parse a possibly-subscripted type expression, returning the resolved
        # object and the number of characters of `expr` consumed.
        i = 0
        while i < len(expr) and expr[i] not in (",", "[", "]"):
            i += 1

        # Special case logic for the empty Tuple as a subscript (used
        # in the type annotation `Tuple[()]`)
        if expr[:i] == "()":
            return (), i

        base = lookupInModule(expr[:i].strip(), module)
        assert base is not None, f"Unresolvable type {expr[:i]}"
        if i == len(expr) or expr[i] != "[":
            return base, i

        assert expr[i] == "["
        # Recursively parse comma-separated subscript arguments until the
        # closing bracket.
        parts = []
        while expr[i] != "]":
            part_len = 0
            i += 1
            part, part_len = parseNestedExpr(expr[i:], module)
            parts.append(part)
            i += part_len
        if len(parts) > 1:
            return base[tuple(parts)], i + 1
        else:
            return base[parts[0]], i + 1

    def parseExpr(expr, module):
        try:
            value, len_parsed = parseNestedExpr(expr, module)
            assert len_parsed == len(
                expr
            ), "whole expression was not parsed, falling back to c++ parser"
            return value
        except Exception:
            """
            The python resolver fails in several cases in known unit tests, and is intended
            to fall back gracefully to the c++ resolver in general.  For example, python 2 style
            annotations which are frequent in our unit tests often fail with types e.g. int not
            resolvable from the calling frame.
            """
            return None

    return lambda expr: parseExpr(expr, lookup_base)
159

160

161
def createResolutionCallbackFromFrame(frames_up: int = 0):
    """
    Creates a function which, given a string variable name,
    returns the value of the variable in the scope of the caller of
    the function which called createResolutionCallbackFromFrame (by default).

    This is used to enable access in-scope Python variables inside
    TorchScript fragments.

    frames_up is number of additional frames to go up on the stack.
    The default value is 0, which correspond to the frame of the caller
    of createResolutionCallbackFromFrame. Also for example, if frames_up is set
    to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
    will be taken.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallbackFromFrame(1)
            print(cb("foo"))

        def baz():
            foo = 2
            bar()

        baz()
    """
    frame = inspect.currentframe()
    i = 0
    # Walk up `frames_up + 1` frames: one to step out of this function itself,
    # plus the requested number of additional caller frames.
    while i < frames_up + 1:
        assert frame is not None
        frame = frame.f_back
        i += 1

    assert frame is not None
    f_locals = frame.f_locals
    f_globals = frame.f_globals

    class env:
        def __getattr__(self, key):
            # Resolution order mirrors Python name lookup: locals, then
            # globals, then builtins.
            if key in f_locals:
                return f_locals[key]
            elif key in f_globals:
                return f_globals[key]
            elif key in dir(builtins):
                return getattr(builtins, key)
            # NOTE: implicitly returns None for names not found anywhere.

    return createResolutionCallbackFromEnv(env())
209

210

211
def get_closure(fn):
    """Return a dict of names visible to ``fn``: its globals plus its closed-over cells."""
    # Start from the function's module globals, then overlay the free
    # variables captured in its closure cells.
    captures = dict(fn.__globals__)
    for name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        captures[name] = cell.cell_contents
    return captures
222

223

224
# [local resolution in python]
225
# Depending on where a variable is defined, and where it is used, we may
226
# or may not be able to recover its value when recursively compiling a
227
# script function. Remember in the general case, a module or function is
228
# first defined and then later scripted. This means we do not have a
229
# chance to capture the active frames when the function is defined. Hence any
230
# name resolution has to happen later on the created closure. The way
231
# python captures type annotations restricts what we can recover. The
232
# follow example illustrates the different cases:
233
#
234
#         class MyGlobalClass:
235
#         ...
236
#         def my_local_scope():
237
#             @torch.jit.script
238
#             class MyClass:
239
#                 ...
240
#             @torch.jit.script
241
#             class MyClassUsedAsVar:
242
#                 ...
243
#             def eg(x: MyClass, y: MyGlobalClass):
244
#                 a_local_capture : Foo
245
#                 return MyClassUsedAsVar(x)
246
#
247
# MyGlobalClass is defined in the __globals__ dictionary of function
248
# 'eg', so it is always recoverable. my_local_scope introduces a new local
249
# variable scope in the function. Classes defined here are only visible as
250
# local variables. For the case of MyClassUsedAsVar, it is captured
251
# because it is used as a variable inside the body of the function, and we
252
# can resolve it using the captures returned from `get_closure`. However,
253
# the type annotations are not captured by the closure. In Python
254
# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
255
# annotations on `eg`, but starting in Python 4.0, they will be represented as
256
# strings and no longer present. Furthermore, since the body of `eg` does
257
# not reference those names, they do not appear in the list of closed over
258
# variables. In Python 2.x, type annotations are in comments, leading to a
259
# similar situation where their definitions are not available. We anticipate
260
# that most users will not run into this issue because their modules and
261
# functions will be defined at a global scope like MyGlobalClass. In cases
262
# where they are not, it is possible to work around issues by declaring the
263
# values global in the function.
264
# In Python 3.9 declaring class as global will make it invisible to
265
# `inspect.getsource`, see https://bugs.python.org/issue42666 .
266
# This could be worked around by manually adding it to the `globals()` dictionary.
267

268

269
def createResolutionCallbackFromClosure(fn):
    """
    Create a resolutionCallback by introspecting the function instead of
    looking up the stack for the enclosing scope
    """
    captured = get_closure(fn)

    class _ClosureEnv:
        # `createResolutionCallbackFromEnv` resolves names via `getattr`,
        # so wrap the captures dict in an object exposing attribute access.
        def __getattr__(self, name):
            if name in captured:
                return captured[name]
            # Fall back to the typing module, then builtins.
            for namespace in (typing, builtins):
                if hasattr(namespace, name):
                    return getattr(namespace, name)
            return None

    return createResolutionCallbackFromEnv(_ClosureEnv())
289

290

291
def can_compile_class(cls) -> bool:
    """Return True if every routine on `cls` carries a Python code object.

    Types with code-object-less routines are probably builtins / bound from C
    and cannot be compiled.
    """
    if is_ignored_fn(cls):
        return False

    # Ignore the following list of built-in classes.
    if issubclass(cls, (torch.nn.Module, tuple, list, Exception)):
        return False

    return all(
        hasattr(getattr(cls, name), "__code__")
        for name in cls.__dict__
        if inspect.isroutine(getattr(cls, name, None))
    )
310

311

312
def get_callable_argument_names(fn) -> List[str]:
    """
    Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.

    Arguments of any other kind (positional-only, ``*args``, keyword-only,
    ``**kwargs``) are skipped, since they do not map to individual values
    with a keyword as name.  (The previous docstring claimed an empty list
    was returned when such arguments were present; the code has always
    skipped them instead.)

    This is used by `torch.jit.trace` to assign meaningful argument names to
    traced functions and modules.

    Args:
        fn: A callable.
    Returns:
        Argument names: List[str]
    """
    # inspect.signature may fail (e.g. for some builtins); give up in that case.
    try:
        callable_signature = inspect.signature(fn)
    except Exception:
        return []

    return [
        name
        for name, param in callable_signature.parameters.items()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    ]
341

342

343
def get_annotation_str(annotation):
    """
    Convert an AST node containing a type annotation to the string present in the source
    that represents the same annotation.

    Returns None for node types not handled here (ScriptTypeParser probably
    handles those).
    """
    if isinstance(annotation, ast.Name):
        return annotation.id
    if isinstance(annotation, ast.Attribute):
        return ".".join([get_annotation_str(annotation.value), annotation.attr])
    if isinstance(annotation, ast.Subscript):
        # In Python3.9+ subscript indicies are not wrapped in ast.Index
        if IS_PY39_PLUS:
            subscript_slice = annotation.slice
        else:
            subscript_slice = annotation.slice.value  # type: ignore[attr-defined]
        value_str = get_annotation_str(annotation.value)
        return f"{value_str}[{get_annotation_str(subscript_slice)}]"
    if isinstance(annotation, ast.Tuple):
        return ",".join(get_annotation_str(elt) for elt in annotation.elts)
    if isinstance(annotation, (ast.Constant, ast.NameConstant)):
        return f"{annotation.value}"
    return None
363

364

365
def get_type_hint_captures(fn):
    """
    Get a dictionary containing type resolution mappings necessary to resolve types
    for the literal annotations on 'fn'. These are not considered to be closed-over by fn
    and must be obtained separately (e.g. using this function).

    Args:
        fn: A callable.
    Returns:
        A Dict[str, Any] containing a mapping from the literal annotations used on
        fn to the Python objects they refer to.
    Raises:
        RuntimeError: if the retrieved source does not parse to a single
            function definition.
    """
    # First, try to get the source of the function. We'll need to parse it to find the actual string names
    # that were used to annotate the types, since inspect.signature() will only return the class object that
    # the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
    # This may happen in cases where the function is synthesized dynamically at runtime.
    src = loader.get_source(fn)
    if src is None:
        src = inspect.getsource(fn)

    # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
    # types are strings. These are only understood by TorchScript in the context of a type annotation
    # that refers to a class in its own definition, but trying to include a mapping for this in the result
    # function would cause infinite recursion because the class is currently being compiled.
    # In addition, there is logic in ScriptTypeParser to handle this.
    signature = inspect.signature(fn)
    name_to_type = {
        name: parameter.annotation
        for name, parameter in signature.parameters.items()
        if parameter.annotation is not inspect.Parameter.empty
        and not isinstance(parameter.annotation, str)
    }

    # Then, get the literal type annotations from the function declaration
    # by source inspection. This accounts for the case in which aliases are used
    # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
    # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
    a = ast.parse(dedent(src))
    if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected {fn} to be a function")
    f = a.body[0]

    # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
    # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
    # them to the type object corresponding to the annotation via name_to_type using the parameter name.
    annotation_to_type = {}

    for arg in f.args.args:
        # Get the source type annotation string for this argument if possible.
        arg_annotation_str = (
            get_annotation_str(arg.annotation) if arg.annotation else None
        )

        # If the argument has no annotation or get_annotation_str cannot convert it to a string,
        # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
        # this in the latter case.
        if arg_annotation_str is None:
            continue

        # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
        # be present in name_to_type is that the annotation itself is a string and not a type object
        # (common for self-referential annotations in classes). Once again, let ScriptTypeParser handle this.
        arg_name = arg.arg
        if arg_name in name_to_type:
            annotation_to_type[arg_annotation_str] = name_to_type[arg_name]

    # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
    # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
    # of the annotation cannot be a string.
    literal_return_annotation = get_annotation_str(f.returns)
    valid_literal_annotation = literal_return_annotation is not None
    return_annotation = signature.return_annotation
    valid_return_annotation_type = (
        return_annotation is not inspect.Parameter.empty
        and not isinstance(return_annotation, str)
    )
    if valid_literal_annotation and valid_return_annotation_type:
        annotation_to_type[literal_return_annotation] = return_annotation

    return annotation_to_type
445

446

447
def createResolutionCallbackForClassMethods(cls):
    """
    This looks at all the methods defined in a class and pulls their closed-over
    variables into a dictionary and uses that to resolve variables.
    """
    # cls is a type here, so `ismethod` would be false: the methods on the type
    # aren't bound to anything, and Python treats them as regular functions.
    routines = (
        getattr(cls, name)
        for name in cls.__dict__
        if inspect.isroutine(getattr(cls, name))
    )
    # Skip built-ins, as they do not have global scope nor type hints.
    # Needed to support `enum.Enum` derived classes in Python-3.11,
    # which adds a `_new_member_` property aliasing `__new__`.
    fns = [
        fn
        for fn in routines
        if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")
    ]

    captures = {}
    for fn in fns:
        captures.update(get_closure(fn))
        captures.update(get_type_hint_captures(fn))

    def lookup_in_class(key):
        # Names not captured by any method fall back to builtins (or None).
        return captures[key] if key in captures else getattr(builtins, key, None)

    return lookup_in_class
476

477

478
def boolean_dispatch(
    arg_name, arg_index, default, if_true, if_false, module_name, func_name
):
    """
    Dispatches to either of 2 script functions based on a boolean argument.
    In TorchScript, the boolean argument must be constant so that the correct
    function to use can be determined at compile time.
    """

    def fn(*args, **kwargs):
        # Resolve the flag: keyword beats positional beats the default.
        if arg_name in kwargs:
            flag = kwargs[arg_name]
        elif arg_index < len(args):
            flag = args[arg_index]
        else:
            flag = default
        target = if_true if flag else if_false
        return target(*args, **kwargs)

    # Share the docstring between the two variants; having docstrings on
    # both is ambiguous and rejected.
    true_doc, false_doc = if_true.__doc__, if_false.__doc__
    if true_doc is None and false_doc is not None:
        doc = false_doc
        if_true.__doc__ = doc
    elif false_doc is None and true_doc is not None:
        doc = true_doc
        if_false.__doc__ = doc
    elif true_doc is None and false_doc is None:
        # neither function has a docstring
        doc = None
    else:
        raise RuntimeError("only one function can have a docstring")
    fn.__doc__ = doc

    if module_name is not None:
        fn.__module__ = module_name
    if func_name is not None:
        fn.__name__ = func_name

    # Register the wrapper so the compiler can recover both branches later.
    boolean_dispatched[fn] = {
        "if_true": if_true,
        "if_false": if_false,
        "index": arg_index,
        "default": default,
        "arg_name": arg_name,
    }
    return fn
525

526

527
class FunctionModifiers:
    """
    Used to denote the behavior of a function in TorchScript. See export() and
    ignore() for details.
    """

    # Calls are replaced with a raise; the body is never compiled.
    UNUSED = "unused (ignored and replaced with raising of an exception)"
    # Left as a call into the Python interpreter; prevents torch.jit.save.
    IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
    # Compiled unconditionally, even if nothing calls it.
    EXPORT = "export (compile this function even if nothing calls it)"
    # Compiled lazily when reached from an exported function / forward.
    DEFAULT = "default (compile if called from a exported function / forward)"
    # Copied verbatim onto the scripted wrapper when not scripted itself.
    COPY_TO_SCRIPT_WRAPPER = (
        "if this method is not scripted, copy the python method onto the scripted model"
    )
    # Fully ignored; even the declaration may be unscriptable.
    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
541

542

543
def export(fn):
    """
    This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
    :class:`ScriptModule` and should be compiled.

    ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
    Functions and methods called from ``forward`` are compiled as they are seen
    by the compiler, so they do not need this decorator either.

    Example (using ``@torch.jit.export`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            def implicitly_compiled_method(self, x):
                return x + 99

            # `forward` is implicitly decorated with `@torch.jit.export`,
            # so adding it here would have no effect
            def forward(self, x):
                return x + 10

            @torch.jit.export
            def another_forward(self, x):
                # When the compiler sees this call, it will compile
                # `implicitly_compiled_method`
                return self.implicitly_compiled_method(x)

            def unused_method(self, x):
                return x - 20

        # `m` will contain compiled methods:
        #     `forward`
        #     `another_forward`
        #     `implicitly_compiled_method`
        # `unused_method` will not be compiled since it was not called from
        # any compiled methods and wasn't decorated with `@torch.jit.export`
        m = torch.jit.script(MyModule())
    """
    # Tag the function so the TorchScript compiler treats it as an entry point.
    fn._torchscript_modifier = FunctionModifiers.EXPORT
    return fn
587

588

589
def unused(fn):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and replaced with the raising of an exception. This allows you
    to leave code in your model that is not yet TorchScript compatible and still
    export your model.

        Example (using ``@torch.jit.unused`` on a method)::

            import torch
            import torch.nn as nn

            class MyModule(nn.Module):
                def __init__(self, use_memory_efficient):
                    super().__init__()
                    self.use_memory_efficient = use_memory_efficient

                @torch.jit.unused
                def memory_efficient(self, x):
                    import pdb
                    pdb.set_trace()
                    return x + 10

                def forward(self, x):
                    # Use not-yet-scriptable memory efficient mode
                    if self.use_memory_efficient:
                        return self.memory_efficient(x)
                    else:
                        return x + 10

            m = torch.jit.script(MyModule(use_memory_efficient=False))
            m.save("m.pt")

            m = torch.jit.script(MyModule(use_memory_efficient=True))
            # exception raised
            m(torch.rand(100))
    """
    if isinstance(fn, property):
        # For properties, mark both accessors so neither is compiled.
        prop = fn
        setattr(  # noqa: B010
            prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED
        )

        if prop.fset:
            setattr(  # noqa: B010
                prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
            )

        return prop

    fn._torchscript_modifier = FunctionModifiers.UNUSED
    return fn
641

642

643
# No op context manager from python side
class _IgnoreContextManager(contextlib.AbstractContextManager):
    """No-op context manager; its body is skipped by the TorchScript compiler."""

    def __init__(self, **kwargs):
        # Accept and discard arbitrary keyword arguments.
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Returning None lets any in-flight exception propagate.
        pass
650

651

652
def ignore(drop=False, **kwargs):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and left as a Python function. This allows you to leave code in
    your model that is not yet TorchScript compatible. If called from TorchScript,
    ignored functions will dispatch the call to the Python interpreter. Models with ignored
    functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.

    Example (using ``@torch.jit.ignore`` on a method)::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore
            def debugger(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                x += 10
                # The compiler would normally try to compile `debugger`,
                # but since it is `@ignore`d, it will be left as a call
                # to Python
                self.debugger(x)
                return x

        m = torch.jit.script(MyModule())

        # Error! The call `debugger` cannot be saved since it calls into Python
        m.save("m.pt")

    Example (using ``@torch.jit.ignore(drop=True)`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore(drop=True)
            def training_method(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                if self.training:
                    self.training_method(x)
                return x

        m = torch.jit.script(MyModule())

        # This is OK since `training_method` is not saved, the call is replaced
        # with a `raise`.
        m.save("m.pt")

    .. testcleanup::

        import os
        os.remove('m.pt')
    """

    if callable(drop):
        # used without any args, so drop is actually a function
        #   @torch.jit.ignore
        #   def fn(...):
        fn = drop
        fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    if not isinstance(drop, bool):
        raise RuntimeError(
            "Argument to @torch.jit.ignore must be a bool or "
            f"a function but got {drop}"
        )

    # for backwards compat
    drop_on_export = kwargs.pop("drop_on_export", None)
    if drop_on_export:
        # Bug fix: both deprecation messages previously ended with a stray
        # " {}" left over from an old str.format call.
        warnings.warn(
            "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
        )

        drop = drop_on_export
    elif drop:
        warnings.warn(
            "ignore(True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
        )

    def decorator(fn):
        if drop:
            fn._torchscript_modifier = FunctionModifiers.UNUSED
        else:
            fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    return decorator
753

754

755
def _drop(fn):
    """Mark `fn` as fully dropped: ignored entirely, declaration may be unscriptable."""
    fn._torchscript_modifier = FunctionModifiers._DROP
    return fn
758

759

760
def _copy_to_script_wrapper(fn):
    """Mark `fn` to be copied onto the scripted wrapper if it is not itself scripted."""
    fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
    return fn
763

764

765
def module_has_exports(mod):
    """Return True if any callable attribute of `mod` carries the EXPORT modifier."""
    for name in dir(mod):
        if not hasattr(mod, name):
            # dir() can report names that fail to resolve; skip those.
            continue
        item = getattr(mod, name)
        if not callable(item):
            continue
        if get_torchscript_modifier(item) is FunctionModifiers.EXPORT:
            return True
    return False
773

774

775
# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
# allow JIT'd code to still be covered.
def should_drop(fn) -> bool:
    """Return True if calls to `fn` should be replaced with a raise (UNUSED or _DROP)."""
    modifier = get_torchscript_modifier(fn)
    if modifier is None:
        return False
    return modifier is FunctionModifiers.UNUSED or modifier is FunctionModifiers._DROP
783

784

785
def is_ignored_fn(fn) -> bool:
    """Return True if `fn` is marked @unused, @ignore, or @_drop (left to Python)."""
    modifier = get_torchscript_modifier(fn)
    ignored_modifiers = (
        FunctionModifiers.UNUSED,
        FunctionModifiers.IGNORE,
        FunctionModifiers._DROP,
    )
    return any(modifier is m for m in ignored_modifiers)
792

793

794
def _is_drop_fn(fn) -> bool:
    """Return True if `fn` was marked with the `_drop` modifier."""
    mod = get_torchscript_modifier(fn)
    return mod is FunctionModifiers._DROP
797

798

799
def is_static_fn(cls, fn) -> bool:
    """Return True if attribute `fn` on `cls` is declared as a staticmethod."""
    # getattr_static fetches the raw descriptor without triggering __get__,
    # so a staticmethod wrapper is still observable.
    static_attr = inspect.getattr_static(cls, fn, default=None)
    return isinstance(static_attr, staticmethod)
801

802

803
def get_static_fn(cls, fn):
    """Return the plain function wrapped by the staticmethod `fn` on `cls`."""
    descriptor = inspect.getattr_static(cls, fn)
    return descriptor.__func__
805

806

807
def get_torchscript_modifier(fn):
808
    if not callable(fn):
809
        return None
810
    if hasattr(fn, "__func__"):
811
        fn = fn.__func__
812
    return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT)
813

814

815
def copy_torchscript_modifier(orig, new) -> None:
    """Propagate `orig`'s TorchScript modifier onto `new`, if it has one."""
    modifier = get_torchscript_modifier(orig)
    if modifier is not None:
        new._torchscript_modifier = modifier
820

821

822
# overloading registration
823
# overloads get registered in this file, and compiled in torch/jit/__init__.py
824
# so that they can be imported in nn/functional.py without an import cycle
825

826
# qualified_name => list[overload_functions]
827
_overloaded_fns: Dict[str, List[Callable]] = {}  # noqa: T484
828

829

830
_OVERLOAD_EXAMPLE = """
831
Example usage of overload function:
832
@torch.jit._overload
833
def my_function(x: type0) -> type0: # decl 1
834
    pass
835

836
@torch.jit._overload
837
def my_function(x: type1) -> type1: # decl 2
838
    pass
839

840
def my_function(x):                 # implementation
841
    if isinstance(x, type0):
842
        return x
843
    elif isinstance(x, type1):
844
        return x
845
"""
846

847

848
def get_overload_no_implementation_error_message(kind, obj):
    """Build the error shown when overload declarations for `obj` lack an implementation.

    Args:
        kind: human-readable kind of the overloaded object (e.g. "function").
        obj: the overloaded callable whose implementing definition is missing.

    Returns:
        A message pointing at the source location of `obj`, followed by a
        usage example of the overload API.
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
    return (
        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
        f"sure a definition is provided and defined after all overload declarations.\n"
        # Bug fix: report the actual source file instead of a hardcoded
        # "(unknown)" placeholder; `filename` was previously unused.
        f'File "{filename}", line {file_lineno}:\n'
        + "".join(sourcelines)
        + "\n"
        + _OVERLOAD_EXAMPLE
    )
858

859

860
def _check_overload_body(func):
861
    try:
862
        parsed_def = parse_def(func)
863
    except OSError as e:
864
        # Parsing the function definition can raise an OSError if source is unavailable.
865
        # Since this is just an initial check, just raise a warning if this is the case.
866
        warnings.warn(
867
            f"Unable to retrieve source for @torch.jit._overload function: {func}."
868
        )
869
        return
870

871
    body = parsed_def.ast.body[0].body
872

873
    def is_pass(x):
874
        return isinstance(x, ast.Pass)
875

876
    def is_ellipsis(x):
877
        return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis)
878

879
    if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
880
        msg = (
881
            "Only `pass` statement or `...` can be the body of overload declaration:\n"
882
        )
883
        msg += "\n".join(parsed_def.source.split("\n")[:3])
884
        msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
885
        raise RuntimeError(msg)
886

887

888
def _overload(func):
889
    _check_overload_body(func)
890
    qual_name = _qualified_name(func)
891
    global _overloaded_fns
892
    fn_overload_list = _overloaded_fns.get(qual_name)
893
    if fn_overload_list is None:
894
        fn_overload_list = []
895
        _overloaded_fns[qual_name] = fn_overload_list
896
    fn_overload_list.append(func)
897
    return func
898

899

900
def _get_fn_overloads(qual_name):
901
    return _overloaded_fns.get(qual_name)
902

903

904
def _clear_fn_overloads(qual_name) -> None:
905
    del _overloaded_fns[qual_name]
906

907

908
def get_class_name_lineno(method) -> Tuple[str, int]:
    """Walk two frames up the call stack and return (code name, first line number) there.

    Expected call chain: class body -> _overload_method -> this function, so the
    frame two levels up is the class body being defined.
    """
    frame = inspect.currentframe()

    # one for the get_class_name call, one for the _overload_method call
    for _ in range(2):
        assert frame is not None  # narrow Optional[FrameType] for mypy
        frame = frame.f_back

    assert frame is not None  # same here
    return frame.f_code.co_name, frame.f_code.co_firstlineno
922

923

924
# At the point the decorator is applied to class methods the method
925
# has no reference to its owning class. _qualified_name would not include
926
# the class it is defined in, so any methods with the same name in the same file
927
# would have the same _qualified_name, even if they were defined in different
928
# classes. This problem only exists in python 2.
929
# We get around this problem by looking at the stack frame and identifying
930
# the class name, and throwing an error whenever overloads are used
931
# when modules of the same name are in the same file
932

933
# qualified_name => class name => list[overload_functions]
934
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {}  # noqa: T484
935

936

937
# (qualified_name, class name) => class_fileno
938
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
939

940

941
def _overload_method(func):
942
    _check_overload_body(func)
943
    qual_name = _qualified_name(func)
944
    global _overloaded_methods
945
    class_name_map = _overloaded_methods.get(qual_name, None)
946
    if class_name_map is None:
947
        class_name_map = {}
948
        _overloaded_methods[qual_name] = class_name_map
949

950
    class_name, line_no = get_class_name_lineno(func)
951
    method_overloads = class_name_map.get(class_name, None)
952
    if method_overloads is None:
953
        method_overloads = []
954
        class_name_map[class_name] = method_overloads
955
        _overloaded_method_class_fileno[(qual_name, class_name)] = line_no
956
    else:
957
        existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)]
958
        if existing_lineno != line_no:
959
            raise RuntimeError(
960
                "Cannot currently overload the same method name in two different"
961
                " classes with the same name in the same module"
962
            )
963

964
    method_overloads.append(func)
965
    return func
966

967

968
def _get_overloaded_methods(method, mod_class):
    """Return the overload declarations registered for `method` on `mod_class`, or None.

    Raises when the method's source location falls outside the class body's
    source span, which happens when a module class is redeclared in one file.
    """
    # TODO: __name__ not set for submodules in recursive script
    if not hasattr(method, "__name__"):
        return None
    qual_name = _qualified_name(method)
    class_name_map = _overloaded_methods.get(qual_name, None)
    if class_name_map is None:
        return None
    overloads = class_name_map.get(mod_class.__name__, None)
    if overloads is None:
        return None

    # Sanity check: the method must be defined within the class body's line
    # range; otherwise overloads registered for a same-named class would leak.
    method_line_no = get_source_lines_and_file(method)[1]
    mod_class_fileno = get_source_lines_and_file(mod_class)[1]
    mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0])
    if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
        raise Exception(
            "Overloads are not useable when a module is redeclared within the same file: "
            + str(method)
        )
    return overloads
989

990

991
def is_tuple(ann) -> bool:
992
    if ann is Tuple:
993
        raise_error_container_parameter_missing("Tuple")
994

995
    # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
996
    if not hasattr(ann, "__module__"):
997
        return False
998

999
    ann_origin = get_origin(ann)
1000
    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
1001
        return True
1002
    return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
1003

1004

1005
def is_list(ann) -> bool:
1006
    if ann is List:
1007
        raise_error_container_parameter_missing("List")
1008

1009
    if not hasattr(ann, "__module__"):
1010
        return False
1011

1012
    ann_origin = get_origin(ann)
1013
    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
1014
        return True
1015
    return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
1016

1017

1018
def is_dict(ann) -> bool:
1019
    if ann is Dict:
1020
        raise_error_container_parameter_missing("Dict")
1021

1022
    if not hasattr(ann, "__module__"):
1023
        return False
1024

1025
    ann_origin = get_origin(ann)
1026
    if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
1027
        return True
1028
    return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
1029

1030

1031
def is_union(ann):
1032
    if ann is Union:
1033
        raise_error_container_parameter_missing("Union")
1034

1035
    return isinstance(ann, BuiltinUnionType) or (
1036
        hasattr(ann, "__module__")
1037
        and ann.__module__ == "typing"
1038
        and (get_origin(ann) is Union)
1039
    )
1040

1041

1042
def is_optional(ann):
    """Return True for Optional[...] annotations, including Union forms that allow None."""
    if ann is Optional:
        raise_error_container_parameter_missing("Optional")

    def spelled_as_optional(a):
        # The literal typing.Optional[...] spelling.
        return (
            hasattr(a, "__module__")
            and a.__module__ == "typing"
            and get_origin(a) is Optional
        )

    def union_allows_none(a):
        # A two-member union where one member is None / NoneType.
        args = get_args(a)
        return len(args) == 2 and (None in args or type(None) in args)

    return spelled_as_optional(ann) or (is_union(ann) and union_allows_none(ann))
1058

1059

1060
def is_future(ann) -> bool:
    """Return True if `ann` is a parameterized torch.futures.Future annotation."""
    if ann is not Future:
        return get_origin(ann) is Future
    # A bare Future is rejected: the JIT needs the contained type.
    raise RuntimeError(
        "Attempted to use Future without a "
        "contained type. Please add a contained type, e.g. "
        "Future[int]"
    )
1068

1069

1070
def is_await(ann) -> bool:
1071
    if ann is _Await:
1072
        return True
1073
    return get_origin(ann) is _Await
1074

1075

1076
# RRef helpers are only meaningful when the RPC module is built/available;
# otherwise provide a stub that reports "not an RRef" for everything.
if torch.distributed.rpc.is_available():
    from torch._C._distributed_rpc import PyRRef
    from torch.distributed.rpc import RRef

    def is_rref(ann) -> bool:
        """Return True if `ann` is a parameterized RRef annotation; bare RRef is an error."""
        if ann is RRef:
            raise RuntimeError(
                "Attempted to use RRef without a "
                "contained type. Please add a contained type, e.g. "
                "RRef[int]"
            )
        return get_origin(ann) is RRef

    def is_rref_instance(obj) -> bool:
        """Return True if `obj` is an actual RRef instance."""
        return isinstance(obj, PyRRef)

else:

    def is_rref_instance(obj) -> bool:
        # If the RPC module doesn't exist then RRefs don't exist either.
        return False
1097

1098

1099
def is_final(ann) -> bool:
    """Return True if `ann` is a Final annotation from typing or typing_extensions."""
    if not hasattr(ann, "__module__"):
        return False
    if ann.__module__ not in ("typing", "typing_extensions"):
        return False
    # Matches both Final[...] (origin check) and the bare Final special form.
    return get_origin(ann) is Final or isinstance(ann, type(Final))
1105

1106

1107
# allows BroadcastingList instance to be subscriptable
class BroadcastingListCls:
    def __getitem__(self, types):
        # Subscripting is accepted purely so the annotation syntax
        # (e.g. BroadcastingList2[int]) parses; the result is always None.
        return
1111

1112

1113
# mypy doesn't support parameters on types, so we have to explicitly type each
# list size
BroadcastingList1 = BroadcastingListCls()
# BroadcastingList2..BroadcastingList6 all alias the same singleton instance;
# only the annotation spelling differs.
for i in range(2, 7):
    globals()[f"BroadcastingList{i}"] = BroadcastingList1
1118

1119

1120
def is_scripting() -> bool:
    r"""
    Function that returns True when in compilation and False otherwise. This
    is useful especially with the @unused decorator to leave code in your
    model that is not yet TorchScript compatible.
    .. testcode::

        import torch

        @torch.jit.unused
        def unsupported_linear_op(x):
            return x

        def linear(x):
           if torch.jit.is_scripting():
              return torch.linear(x)
           else:
              return unsupported_linear_op(x)
    """
    # Always False in eager mode; per the docstring, the compiled path
    # evaluates this call to True instead.
    return False
1140

1141

1142
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
def _qualified_name(obj, mangle_name=True) -> str:
    """Return the TorchScript qualified name for `obj` (module path + name).

    With `mangle_name=True` (the default) the module path is prefixed with
    "__torch__" to keep compiler names disjoint from user value names.
    Raises RuntimeError when a usable name/module cannot be determined.
    """
    # This special case allows us to override the qualified name on a type.
    # It's currently used in conjunction with tracing, where we create a
    # fake module to filter only supported attributes. However, since this
    # new type is defined as a local class, we need a mechanism to override
    # its qualname so it appears correctly in the TorchScript system. Thus,
    # we set '_jit_override_qualname' with the original traced module's
    # qualified name, which is picked up here
    if hasattr(obj, "_jit_override_qualname"):
        return obj._jit_override_qualname
    # short-circuit in cases where the object already has a known qualified name
    if isinstance(obj, torch._C.ScriptFunction):
        return obj.qualified_name

    if getattr(obj, "__name__", None):
        name = obj.__name__
    # Enum classes do not have `__name__` attr, instead they have `name`.
    elif isinstance(obj, enum.Enum):
        name = obj.name
    else:
        raise RuntimeError("Could not get name of python class object")

    if name == "<lambda>":
        name = "_lambda"  # make name a valid identifier

    module_name = obj.__module__

    # If the module is actually a torchbind module, then we should short circuit
    if module_name == "torch._classes":
        return obj.qualified_name

    # The Python docs are very clear that `__module__` can be None, but I can't
    # figure out when it actually would be.
    if module_name is None:
        raise RuntimeError(
            f"Could not get qualified name for class '{name}': "
            "__module__ can't be None."
        )

    # if getattr(sys.modules[module_name], name) is not obj:
    #     raise RuntimeError(f"Could not get qualified name for class '{name}': "
    #                        f"the attr {name} on module {module_name} is not the class")

    # torch.package and TorchScript have separate mangling schemes to avoid
    # name collisions from multiple packages. To avoid them interfering with
    # each other, normalize the package mangling here.
    if package_mangling.is_mangled(module_name):
        module_name = module_name.replace("<", "_")
        module_name = module_name.replace(">", "_")

    # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
    # does not need to mangle the python class name.
    if mangle_name:
        # __main__ is a builtin module, so rewrite it to "__torch__".
        if module_name == "__main__":
            module_name = "__torch__"
        else:
            # Everything else gets a "__torch__" prefix to avoid name collisions
            # with the names of user values.
            module_name = "__torch__." + module_name

    if "." in name:
        raise RuntimeError(
            f"Could not get qualified name for class '{name}': "
            f"'{name}' is not a valid identifier"
        )

    return module_name + "." + name
1211

1212

1213
def _try_get_dispatched_fn(fn):
1214
    if not callable(fn):
1215
        return None
1216
    return boolean_dispatched.get(fn)
1217

1218

1219
def _get_named_tuple_properties(
    obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None
):
    """Extract (name, fields, TorchScript types, defaults) from a NamedTuple class.

    Args:
        obj: a NamedTuple *class* (asserted to subclass tuple with `_fields`).
        loc: source range used for error reporting; a fake range when omitted.
        rcb: resolution callback used to resolve ForwardRef string annotations.

    Returns:
        Tuple of (type name, field names, resolved annotations, default values).
    """
    if loc is None:
        loc = fake_range()

    assert issubclass(obj, tuple) and hasattr(obj, "_fields")
    if hasattr(obj, "_field_defaults"):
        # Preserve field order while collecting only fields that have defaults.
        defaults = [
            obj._field_defaults[field]
            for field in obj._fields
            if field in obj._field_defaults
        ]
    else:
        defaults = []
    # In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
    # Also, annotations from base class are not inherited so they need to be queried explicitly
    if sys.version_info[:2] < (3, 10):
        obj_annotations = getattr(obj, "__annotations__", {})
    else:
        obj_annotations = inspect.get_annotations(obj)
        if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
            obj_annotations = inspect.get_annotations(obj.__base__)

    annotations = []
    for field in obj._fields:
        if field in obj_annotations:
            field_type = obj_annotations[field]
            # [Note: ForwardRef annotations in NamedTuple attributes]
            # NamedTuple types are slightly different from normal types.
            #
            # Normally, annotations are evaluted like this (during jit.script):
            # 1. Load strings of python code into c++ and parse.
            # 2. Get annotations as strings
            # 3. Use the PythonResolver's resolution callback (rcb) to convert
            #    the string into a python object
            # 4. We call into annotations.py:ann_to_type to convert python obj
            #    from step 3 into a type that torchscript understands.
            #
            # NamedTuples are more complicated, because it has sub-types.
            # Normally, once we have the NamedTuple type object from #3,
            # we can just look at the annotation literal values and use
            # ann_to_type directly on them.
            #
            # But sometimes, users will annotate with string literals, e.g.
            #    x: 'int'
            # This also happens with PEP563 (from __forward__ import annotations)
            #
            # These annotations appear in the annotation dict as ForwardRef('int').
            #
            # Then, we need to convert the string into a python object. This
            # requires having local context for custom objects or imported types.
            # rcb() is what gives us this. So, we plumb rcb through the stack so
            # it can be used in this context for the if block below.
            #
            # FAQ:
            # - Why do we need this special handling for NamedTuple but string
            #   annotations work fine for normal types? Normally, we parse the
            #   string directly and then call rcb() directly from C++.
            # - Why not use ForwardRef._evaluate? For that, we need globals()
            #   and locals() for the local context where the NamedTuple was defined.
            #   rcb is what lets us look up into these. So, basically rcb does the
            #   hard work for us.
            if isinstance(field_type, ForwardRef) and rcb is not None:
                rcb_type = rcb(field_type.__forward_arg__)
                # rcb returns None if it can't find anything.
                if rcb_type is None:
                    raise ValueError(
                        f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
                        f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
                        f" Issue occurred at {loc.highlight()}"
                    )
                field_type = rcb_type
            the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
            annotations.append(the_type)
        else:
            # Fields with no annotation fall back to an inferred Tensor type.
            annotations.append(torch._C.TensorType.getInferred())
    # NOTE(review): since `obj` is a class, `type(obj).__name__` is the name of
    # its metaclass, not of the NamedTuple itself — confirm callers expect this
    # (it looks like `obj.__name__` may have been intended).
    return type(obj).__name__, obj._fields, annotations, defaults
1297

1298

1299
def _create_named_tuple(
1300
    t, unqual_name: str, field_names: List[str], defaults: Tuple[Any, ...]
1301
):
1302
    TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults)  # type: ignore[call-arg, no-redef, misc]
1303
    return TupleType(*t)
1304

1305

1306
@contextlib.contextmanager
def _disable_emit_hooks():
    """Context manager that clears the JIT emit hooks and restores the previous hooks on exit."""
    hooks = torch._C._jit_get_emit_hooks()
    torch._C._jit_set_emit_hooks(None, None)
    try:
        yield
    finally:
        # Restore even when the wrapped block raises.
        torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
1314

1315

1316
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:  # noqa: F811
    # NOTE(review): these locally defined __enter__/__exit__ are never attached
    # to `_DecoratorContextManager` within this function, and nothing is
    # returned — the definitions are unused in this scope as written.
    def __enter__(self) -> None:
        self.hooks = torch._C._jit_get_emit_hooks()
        torch._C._jit_set_emit_hooks(None, None)

    def __exit__(self, *args) -> None:
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
1323

1324

1325
def _is_exception(obj) -> bool:
1326
    if not inspect.isclass(obj):
1327
        return False
1328
    return issubclass(obj, Exception)
1329

1330

1331
def raise_error_container_parameter_missing(target_type) -> None:
    """Raise a RuntimeError telling the user that a bare container annotation lacks contained type(s)."""
    if target_type == "Dict":
        # Dict needs two contained types, so it gets dedicated wording.
        message = (
            "Attempted to use Dict without "
            "contained types. Please add contained type, e.g. "
            "Dict[int, int]"
        )
    else:
        message = (
            f"Attempted to use {target_type} without a "
            "contained type. Please add a contained type, e.g. "
            f"{target_type}[int]"
        )
    raise RuntimeError(message)
1343

1344

1345
def check_args_exist(target_type) -> None:
    """Reject bare container annotations (List/Tuple/Dict/Optional) that are missing type arguments."""
    bare_forms = (
        ("List", (List, list)),
        ("Tuple", (Tuple, tuple)),
        ("Dict", (Dict, dict)),
        ("Optional", (None, Optional)),
    )
    for label, forms in bare_forms:
        # Identity checks, matching the bare typing form or the builtin class.
        if any(target_type is form for form in forms):
            raise_error_container_parameter_missing(label)
1354

1355

1356
def check_empty_containers(obj) -> None:
    """Warn that torch.jit.isinstance cannot verify element types of empty containers in eager mode."""
    if any(obj == empty for empty in ([], {}, ())):
        warnings.warn(
            "The inner type of a container is lost when "
            "calling torch.jit.isinstance in eager mode. For "
            "example, List[int] would become list and "
            "therefore falsely return True for List[float] or"
            " List[str]."
        )
1365

1366

1367
# supports List/Dict/Tuple and Optional types
1368
# TODO support future
1369
def container_checker(obj, target_type) -> bool:
1370
    origin_type = get_origin(target_type)
1371
    check_args_exist(target_type)
1372
    if origin_type is None:
1373
        return False
1374
    elif origin_type is list or origin_type is List:
1375
        check_empty_containers(obj)
1376
        if not isinstance(obj, list):
1377
            return False
1378
        arg_type = get_args(target_type)[0]
1379
        arg_origin = get_origin(arg_type)
1380
        for el in obj:
1381
            # check if nested container, ex: List[List[str]]
1382
            if arg_origin:  # processes nested container, ex: List[List[str]]
1383
                if not container_checker(el, arg_type):
1384
                    return False
1385
            elif not isinstance(el, arg_type):
1386
                return False
1387
        return True
1388
    elif origin_type is Dict or origin_type is dict:
1389
        check_empty_containers(obj)
1390
        if not isinstance(obj, dict):
1391
            return False
1392
        key_type = get_args(target_type)[0]
1393
        val_type = get_args(target_type)[1]
1394
        for key, val in obj.items():
1395
            # check if keys are of right type
1396
            if not isinstance(key, key_type):
1397
                return False
1398
            val_origin = get_origin(val_type)
1399
            if val_origin:
1400
                if not container_checker(val, val_type):
1401
                    return False
1402
            elif not isinstance(val, val_type):
1403
                return False
1404
        return True
1405
    elif origin_type is Tuple or origin_type is tuple:
1406
        check_empty_containers(obj)
1407
        if not isinstance(obj, tuple):
1408
            return False
1409
        arg_types = get_args(target_type)
1410
        if len(obj) != len(arg_types):
1411
            return False
1412
        for el, el_type in zip(obj, arg_types):
1413
            el_origin = get_origin(el_type)
1414
            if el_origin:
1415
                if not container_checker(el, el_type):
1416
                    return False
1417
            elif not isinstance(el, el_type):
1418
                return False
1419
        return True
1420
    elif origin_type is Union or issubclass(
1421
        origin_type, BuiltinUnionType
1422
    ):  # also handles Optional
1423
        if obj is None:  # check before recursion because None is always fine
1424
            return True
1425
        inner_types = get_args(target_type)
1426
        for t in inner_types:
1427
            t_origin = get_origin(t)
1428
            if t_origin:
1429
                return container_checker(obj, t)
1430
            elif isinstance(obj, t):
1431
                return True
1432
    return False
1433

1434

1435
def _isinstance(obj, target_type) -> bool:
1436
    if isinstance(target_type, collections.abc.Container):
1437
        if not isinstance(target_type, tuple):
1438
            raise RuntimeError(
1439
                "The second argument to "
1440
                "`torch.jit.isinstance` must be a type "
1441
                "or a tuple of types"
1442
            )
1443
        for t_type in target_type:
1444
            if _isinstance(obj, t_type):
1445
                return True
1446
        return False
1447

1448
    origin_type = get_origin(target_type)
1449
    if origin_type:
1450
        return container_checker(obj, target_type)
1451

1452
    # Check to handle non-typed optional origin returns as none instead
1453
    #    of as optional in 3.7-3.8
1454
    check_args_exist(target_type)
1455

1456
    # handle non-containers
1457
    return isinstance(obj, target_type)
1458

1459

1460
class _TensorExtractor(pickle.Pickler):
1461
    def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
1462
        super().__init__(*args, **kwargs)
1463
        self.tensors = tensors
1464

1465
    def persistent_id(self, obj):
1466
        if isinstance(obj, torch.Tensor):
1467
            self.tensors.append(obj)
1468
            return ""
1469
        # Since we just want to extract tensors, we don't mind if an object is
1470
        # unpicklable if it doesn't contain tensors, as we can just ignore/skip
1471
        # it. To play it safe, we only do so for common objects that we're sure
1472
        # don't contain tensors. Feel free to add new types here. Note also that
1473
        # even if a type isn't listed here this won't block users, since thet
1474
        # can just add a __getstate__ or __reduce__ method to their class.
1475
        if isinstance(obj, LockType):
1476
            return ""
1477
        # Futures and RRefs don't technically contain a value, they just offer
1478
        # the means to access a value.
1479
        if isinstance(obj, CFuture) or is_rref_instance(obj):
1480
            return ""
1481
        if isinstance(obj, CAwait):
1482
            return ""
1483
        if isinstance(obj, torch.cuda.Event):
1484
            return ""
1485
        if isinstance(obj, threading.Thread):
1486
            return ""
1487
        return None
1488

1489

1490
def _extract_tensors(obj):
1491
    r"""
1492
    This function is exclusively called from C++.
1493
    See ``torch/csrc/jit/python/python_ivalue.h``.
1494

1495
    It extracts the tensors contained in the given object, through pickling.
1496
    """
1497
    tensors: List[torch.Tensor] = []
1498
    extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
1499
    extractor.dump(obj)
1500
    return tensors
1501

1502

1503
# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
# that were previously dropped. To preserve the behavior, explicitly drop them there

if sys.version_info > (3, 10):
    # Mark these Enum dunders with the _drop modifier so the JIT skips them.
    _drop(enum.Enum.__new__)
    _drop(enum.Enum.__format__)
    _drop(enum.Enum.__repr__)
    _drop(enum.Enum.__str__)
1511

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.