2
The weak_script annotation needs to be here instead of inside torch/jit/ so it
3
can be used in other places in torch/ (namely torch.nn) without running into
4
circular dependency problems
21
from textwrap import dedent
22
from typing import ( # noqa: F401
29
get_args, # new in 3.8
30
get_origin, # new in 3.8
41
# This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`.
42
# Explicitly ask to import `torch.distributed.__init__` first.
43
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
44
import torch.distributed.rpc
45
import torch.package._mangling as package_mangling
46
from torch._awaits import _Await
47
from torch._C import _Await as CAwait, Future as CFuture
48
from torch._sources import fake_range, get_source_lines_and_file, parse_def
49
from torch.futures import Future
51
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
52
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
54
BuiltinUnionType: Union[Type, Tuple[Type, ...]]
55
if sys.version_info >= (3, 10):
56
# NOTE: IS_PY310_PLUS doesn't work with mypy.
57
# cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
58
BuiltinUnionType = types.UnionType
60
BuiltinUnionType = () # trick: this makes isinstance short circuit.
66
LockType = _thread.LockType
68
import _dummy_thread # type: ignore[import-not-found]
70
LockType = _dummy_thread.LockType
72
# Wrapper functions that can call either of 2 functions depending on a boolean
74
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
75
weakref.WeakKeyDictionary()
79
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
86
def cache(self, fn, source):
    """Record ``source`` under the key ``fn`` in ``self.content``.

    Any entry previously stored for ``fn`` is overwritten.
    """
    # NOTE(review): assumes ``self.content`` is a mutable mapping set up
    # elsewhere in the class -- its initialisation is not visible here.
    store = self.content
    store[fn] = source
89
def get_source(self, fn):
    """Look up the entry stored for ``fn`` in ``self.content``.

    Returns whatever ``self.content.get(fn)`` yields, i.e. ``None`` when
    nothing has been stored for ``fn``.
    """
    # NOTE(review): assumes ``self.content`` is a mapping -- its
    # initialisation is not visible here.
    cached_source = self.content.get(fn)
    return cached_source
93
loader = SourceLoader()
96
def createResolutionCallbackFromEnv(lookup_base):
98
Creates a resolution callback that will look up qualified names in an
99
environment, starting with `lookup_base` for the base of any qualified
100
names, then proceeding down the lookup chain with the resolved object.
102
You should not use this directly, it should only be used from the other
103
createResolutionCallbackFrom* functions.
106
def lookupInModule(qualified_name, module):
107
if "." in qualified_name:
108
base, remaining_pieces = qualified_name.split(".", maxsplit=1)
109
module_value = getattr(module, base)
110
return lookupInModule(remaining_pieces, module_value)
112
return getattr(module, qualified_name)
114
def parseNestedExpr(expr, module) -> Tuple[Any, int]:
116
while i < len(expr) and expr[i] not in (",", "[", "]"):
119
# Special case logic for the empty Tuple as a subscript (used
120
# in the type annotation `Tuple[()]`)
124
base = lookupInModule(expr[:i].strip(), module)
125
assert base is not None, f"Unresolvable type {expr[:i]}"
126
if i == len(expr) or expr[i] != "[":
129
assert expr[i] == "["
131
while expr[i] != "]":
134
part, part_len = parseNestedExpr(expr[i:], module)
138
return base[tuple(parts)], i + 1
140
return base[parts[0]], i + 1
142
def parseExpr(expr, module):
144
value, len_parsed = parseNestedExpr(expr, module)
145
assert len_parsed == len(
147
), "whole expression was not parsed, falling back to c++ parser"
151
The python resolver fails in several cases in known unit tests, and is intended
152
to fall back gracefully to the c++ resolver in general. For example, python 2 style
153
annotations which are frequent in our unit tests often fail with types e.g. int not
154
resolvable from the calling frame.
158
return lambda expr: parseExpr(expr, lookup_base)
161
def createResolutionCallbackFromFrame(frames_up: int = 0):
163
Creates a function which, given a string variable name,
164
returns the value of the variable in the scope of the caller of
165
the function which called createResolutionCallbackFromFrame (by default).
167
This is used to enable access in-scope Python variables inside
168
TorchScript fragments.
170
frames_up is number of additional frames to go up on the stack.
171
The default value is 0, which correspond to the frame of the caller
172
of createResolutionCallbackFromFrame. Also for example, if frames_up is set
173
to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
176
For example, the following program prints 2::
179
cb = createResolutionCallbackFromFrame(1)
188
frame = inspect.currentframe()
190
while i < frames_up + 1:
191
assert frame is not None
195
assert frame is not None
196
f_locals = frame.f_locals
197
f_globals = frame.f_globals
200
def __getattr__(self, key):
203
elif key in f_globals:
204
return f_globals[key]
205
elif key in dir(builtins):
206
return getattr(builtins, key)
208
return createResolutionCallbackFromEnv(env())
213
Get a dictionary of closed over variables from a function
216
captures.update(fn.__globals__)
218
for index, captured_name in enumerate(fn.__code__.co_freevars):
219
captures[captured_name] = fn.__closure__[index].cell_contents
224
# [local resolution in python]
225
# Depending on where a variable is defined, and where it is used, we may
226
# or may not be able to recover its value when recursively compiling a
227
# script function. Remember in the general case, a module or function is
228
# first defined and then later scripted. This means we do not have a
229
# chance to capture the active frames when the function is defined. Hence any
230
# name resolution has to happen later on the created closure. The way
231
# python captures type annotations restricts what we can recover. The
232
# following example illustrates the different cases:
234
# class MyGlobalClass:
236
# def my_local_scope():
241
# class MyClassUsedAsVar:
243
# def eg(x: MyClass, y: MyGlobalClass):
244
# a_local_capture : Foo
245
# return MyClassUsedAsVar(x)
247
# MyGlobalClass is defined in the __globals__ dictionary of function
248
# 'eg', so it is always recoverable. my_local_scope introduces a new local
249
# variable scope in the function. Classes defined here are only visible as
250
# local variables. For the case of MyClassUsedAsVar, it is captured
251
# because it is used as a variable inside the body of the function, and we
252
# can resolve it using the captures returned from `get_closure`. However,
253
# the type annotations are not captured by the closure. In Python
254
# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
255
# annotations on `eg`, but starting in Python 4.0, they will be represented as
256
# strings and no longer present. Furthermore, since the body of `eg` does
257
# not reference those names, they do not appear in the list of closed over
258
# variables. In Python 2.x, type annotations are in comments, leading to a
259
# similar situation where their definitions are not available. We anticipate
260
# that most users will not run into this issue because their modules and
261
# functions will be defined at a global scope like MyGlobalClass. In cases
262
# where they are not, it is possible to work around issues by declaring the
263
# values global in the function.
264
# In Python 3.9 declaring class as global will make it invisible to
265
# `inspect.getsource`, see https://bugs.python.org/issue42666 .
266
# This could be worked around by manually adding it to the `globals()` dictionary.
269
def createResolutionCallbackFromClosure(fn):
271
Create a resolutionCallback by introspecting the function instead of
272
looking up the stack for the enclosing scope
274
closure = get_closure(fn)
276
class closure_lookup:
277
# This is a class since `closure` is a dict and it's easier in
278
# `env_helper` if everything just works with `getattr` calls
279
def __getattr__(self, key):
282
elif hasattr(typing, key):
283
return getattr(typing, key)
284
elif hasattr(builtins, key):
285
return getattr(builtins, key)
288
return createResolutionCallbackFromEnv(closure_lookup())
291
def can_compile_class(cls) -> bool:
292
# If any of the functions on a type don't have a code object, this type can't
293
# be compiled and is probably a builtin / bound from C
294
if is_ignored_fn(cls):
297
# Ignore the following list of built-in classes.
298
ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
299
if issubclass(cls, ignored_builtin_classes):
306
if inspect.isroutine(getattr(cls, name, None))
308
has_code = [hasattr(fn, "__code__") for fn in fns]
312
def get_callable_argument_names(fn) -> List[str]:
314
Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.
315
Returns an empty list when other types of arguments are present.
317
This is used by `torch.jit.trace` to assign meaningful argument names to
318
traced functions and modules.
323
Argument names: List[str]
325
# inspect.signature may fail, give up in that case.
327
callable_signature = inspect.signature(fn)
332
for name, param in callable_signature.parameters.items():
333
# All four other types of arguments do not map to individual values
334
# with a keyword as name.
335
if not param.kind == param.POSITIONAL_OR_KEYWORD:
338
argument_names.append(name)
340
return argument_names
343
def get_annotation_str(annotation):
345
Convert an AST node containing a type annotation to the string present in the source
346
that represents the same annotation.
348
if isinstance(annotation, ast.Name):
350
elif isinstance(annotation, ast.Attribute):
351
return ".".join([get_annotation_str(annotation.value), annotation.attr])
352
elif isinstance(annotation, ast.Subscript):
353
# In Python 3.9+, subscript indices are not wrapped in ast.Index
354
subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value # type: ignore[attr-defined]
355
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
356
elif isinstance(annotation, ast.Tuple):
357
return ",".join([get_annotation_str(elt) for elt in annotation.elts])
358
elif isinstance(annotation, (ast.Constant, ast.NameConstant)):
359
return f"{annotation.value}"
361
# If an AST node is not handled here, it's probably handled in ScriptTypeParser.
365
def get_type_hint_captures(fn):
367
Get a dictionary containing type resolution mappings necessary to resolve types
368
for the literal annotations on 'fn'. These are not considered to be closed-over by fn
369
and must be obtained separately (e.g. using this function).
374
A Dict[str, Any] containing a mapping from the literal annotations used on
375
fn to the Python objects they refer to.
377
# First, try to get the source of the function. We'll need to parse it to find the actual string names
378
# that were used to annotate the types, since inspect.signature() will only return the class object that
379
# the annotation refers to, not the string name. If we can't get the source, simply return an empty dict.
380
# This may happen in cases where the function is synthesized dynamically at runtime.
381
src = loader.get_source(fn)
383
src = inspect.getsource(fn)
385
# Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
386
# types are strings. These are only understood by TorchScript in the context of a type annotation
387
# that refers to a class in its own definition, but trying to include a mapping for this in the result
388
# function would cause infinite recursion because the class is currently being compiled.
389
# In addition, there is logic in ScriptTypeParser to handle this.
390
signature = inspect.signature(fn)
392
name: parameter.annotation
393
for name, parameter in signature.parameters.items()
394
if parameter.annotation is not inspect.Parameter.empty
395
and not isinstance(parameter.annotation, str)
398
# Then, get the literal type annotations from the function declaration
399
# by source inspection. This accounts for the case in which aliases are used
400
# to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
401
# frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
402
a = ast.parse(dedent(src))
403
if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
404
raise RuntimeError(f"Expected {fn} to be a function")
407
# Prepare a dictionary of source annotation -> type, which will be the final result of this function,
408
# by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
409
# them to the type object corresponding to the annotation via name_to_type using the parameter name.
410
annotation_to_type = {}
412
for arg in f.args.args:
413
# Get the source type annotation string for this argument if possible.
414
arg_annotation_str = (
415
get_annotation_str(arg.annotation) if arg.annotation else None
418
# If the argument has no annotation or get_annotation_str cannot convert it to a string,
419
# arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
420
# this in the latter case.
421
if arg_annotation_str is None:
424
# Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
425
# be present in name_to_type is that the annotation itself is a string and not a type object
426
# (common for self-referential annotations in classes). Once again, let ScriptTypeParser handle this.
428
if arg_name in name_to_type:
429
annotation_to_type[arg_annotation_str] = name_to_type[arg_name]
431
# If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
432
# the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
433
# of the annotation cannot be a string.
434
literal_return_annotation = get_annotation_str(f.returns)
435
valid_literal_annotation = literal_return_annotation is not None
436
return_annotation = signature.return_annotation
437
valid_return_annotation_type = (
438
return_annotation is not inspect.Parameter.empty
439
and not isinstance(return_annotation, str)
441
if valid_literal_annotation and valid_return_annotation_type:
442
annotation_to_type[literal_return_annotation] = return_annotation
444
return annotation_to_type
447
def createResolutionCallbackForClassMethods(cls):
449
This looks at all the methods defined in a class and pulls their closed-over
450
variables into a dictionary and uses that to resolve variables.
452
# cls is a type here, so `ismethod` is false since the methods on the type
453
# aren't bound to anything, so Python treats them as regular functions
456
for name in cls.__dict__
457
if inspect.isroutine(getattr(cls, name))
459
# Skip built-ins, as they do not have global scope nor type hints
460
# Needed to support `enum.Enum` derived classes in Python-3.11
461
# That adds `_new_member_` property which is an alias to `__new__`
462
fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")]
466
captures.update(get_closure(fn))
467
captures.update(get_type_hint_captures(fn))
469
def lookup_in_class(key):
473
return getattr(builtins, key, None)
475
return lookup_in_class
479
arg_name, arg_index, default, if_true, if_false, module_name, func_name
482
Dispatches to either of 2 script functions based on a boolean argument.
483
In TorchScript, the boolean argument must be constant so that the correct
484
function to use can be determined at compile time.
487
def fn(*args, **kwargs):
488
dispatch_flag = default
489
if arg_name in kwargs:
490
dispatch_flag = kwargs[arg_name]
491
elif arg_index < len(args):
492
dispatch_flag = args[arg_index]
495
return if_true(*args, **kwargs)
497
return if_false(*args, **kwargs)
499
if if_true.__doc__ is None and if_false.__doc__ is not None:
500
doc = if_false.__doc__
501
if_true.__doc__ = doc
502
elif if_false.__doc__ is None and if_true.__doc__ is not None:
503
doc = if_true.__doc__
504
if_false.__doc__ = doc
505
elif if_false.__doc__ is None and if_true.__doc__ is None:
506
# neither function has a docstring
509
raise RuntimeError("only one function can have a docstring")
512
if module_name is not None:
513
fn.__module__ = module_name
514
if func_name is not None:
515
fn.__name__ = func_name
517
boolean_dispatched[fn] = {
519
"if_false": if_false,
522
"arg_name": arg_name,
527
class FunctionModifiers:
529
Used to denote the behavior of a function in TorchScript. See export() and
530
ignore() for details.
533
UNUSED = "unused (ignored and replaced with raising of an exception)"
534
IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
535
EXPORT = "export (compile this function even if nothing calls it)"
536
DEFAULT = "default (compile if called from a exported function / forward)"
537
COPY_TO_SCRIPT_WRAPPER = (
538
"if this method is not scripted, copy the python method onto the scripted model"
540
_DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
545
This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
546
:class:`ScriptModule` and should be compiled.
548
``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
549
Functions and methods called from ``forward`` are compiled as they are seen
550
by the compiler, so they do not need this decorator either.
552
Example (using ``@torch.jit.export`` on a method):
557
import torch.nn as nn
559
class MyModule(nn.Module):
560
def implicitly_compiled_method(self, x):
563
# `forward` is implicitly decorated with `@torch.jit.export`,
564
# so adding it here would have no effect
565
def forward(self, x):
569
def another_forward(self, x):
570
# When the compiler sees this call, it will compile
571
# `implicitly_compiled_method`
572
return self.implicitly_compiled_method(x)
574
def unused_method(self, x):
577
# `m` will contain compiled methods:
580
# `implicitly_compiled_method`
581
# `unused_method` will not be compiled since it was not called from
582
# any compiled methods and wasn't decorated with `@torch.jit.export`
583
m = torch.jit.script(MyModule())
585
fn._torchscript_modifier = FunctionModifiers.EXPORT
591
This decorator indicates to the compiler that a function or method should
592
be ignored and replaced with the raising of an exception. This allows you
593
to leave code in your model that is not yet TorchScript compatible and still
596
Example (using ``@torch.jit.unused`` on a method)::
599
import torch.nn as nn
601
class MyModule(nn.Module):
602
def __init__(self, use_memory_efficient):
604
self.use_memory_efficient = use_memory_efficient
607
def memory_efficient(self, x):
612
def forward(self, x):
613
# Use not-yet-scriptable memory efficient mode
614
if self.use_memory_efficient:
615
return self.memory_efficient(x)
619
m = torch.jit.script(MyModule(use_memory_efficient=False))
622
m = torch.jit.script(MyModule(use_memory_efficient=True))
626
if isinstance(fn, property):
628
setattr( # noqa: B010
629
prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED
633
setattr( # noqa: B010
634
prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
639
fn._torchscript_modifier = FunctionModifiers.UNUSED
643
# No op context manager from python side
644
class _IgnoreContextManager(contextlib.AbstractContextManager):
645
def __init__(self, **kwargs):
648
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
652
def ignore(drop=False, **kwargs):
654
This decorator indicates to the compiler that a function or method should
655
be ignored and left as a Python function. This allows you to leave code in
656
your model that is not yet TorchScript compatible. If called from TorchScript,
657
ignored functions will dispatch the call to the Python interpreter. Models with ignored
658
functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.
660
Example (using ``@torch.jit.ignore`` on a method)::
663
import torch.nn as nn
665
class MyModule(nn.Module):
667
def debugger(self, x):
671
def forward(self, x):
673
# The compiler would normally try to compile `debugger`,
674
# but since it is `@ignore`d, it will be left as a call
679
m = torch.jit.script(MyModule())
681
# Error! The call `debugger` cannot be saved since it calls into Python
684
Example (using ``@torch.jit.ignore(drop=True)`` on a method):
689
import torch.nn as nn
691
class MyModule(nn.Module):
692
@torch.jit.ignore(drop=True)
693
def training_method(self, x):
697
def forward(self, x):
699
self.training_method(x)
702
m = torch.jit.script(MyModule())
704
# This is OK since `training_method` is not saved, the call is replaced
715
# used without any args, so drop is actually a function
719
fn._torchscript_modifier = FunctionModifiers.IGNORE
722
if not isinstance(drop, bool):
724
"Argument to @torch.jit.ignore must be a bool or "
725
f"a function but got {drop}"
728
# for backwards compat
729
drop_on_export = kwargs.pop("drop_on_export", None)
732
"ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
733
"call on compilation. Use torch.jit.unused now. {}",
734
category=FutureWarning,
737
drop = drop_on_export
740
"ignore(True) has been deprecated. TorchScript will now drop the function "
741
"call on compilation. Use torch.jit.unused now. {}",
742
category=FutureWarning,
747
fn._torchscript_modifier = FunctionModifiers.UNUSED
749
fn._torchscript_modifier = FunctionModifiers.IGNORE
756
fn._torchscript_modifier = FunctionModifiers._DROP
760
def _copy_to_script_wrapper(fn):
761
fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
765
def module_has_exports(mod):
766
for name in dir(mod):
767
if hasattr(mod, name):
768
item = getattr(mod, name)
770
if get_torchscript_modifier(item) is FunctionModifiers.EXPORT:
775
# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
776
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
777
# allow JIT'd code to still be covered.
778
def should_drop(fn) -> bool:
779
attr = get_torchscript_modifier(fn)
782
return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP
785
def is_ignored_fn(fn) -> bool:
786
mod = get_torchscript_modifier(fn)
788
mod is FunctionModifiers.UNUSED
789
or mod is FunctionModifiers.IGNORE
790
or mod is FunctionModifiers._DROP
794
def _is_drop_fn(fn) -> bool:
    """Return True when ``fn``'s TorchScript modifier is ``FunctionModifiers._DROP``.

    The modifier is obtained through ``get_torchscript_modifier`` and
    compared by identity.
    """
    return get_torchscript_modifier(fn) is FunctionModifiers._DROP
799
def is_static_fn(cls, fn) -> bool:
    """Tell whether the attribute named ``fn`` on ``cls`` is a ``staticmethod``.

    The attribute is fetched with ``inspect.getattr_static`` so the raw
    ``staticmethod`` wrapper is seen rather than the function it unwraps to.
    A missing attribute yields ``False``.
    """
    raw_attr = inspect.getattr_static(cls, fn, default=None)
    return isinstance(raw_attr, staticmethod)
803
def get_static_fn(cls, fn):
    """Return the function wrapped by the static attribute ``fn`` of ``cls``.

    The attribute is fetched with ``inspect.getattr_static`` and its
    ``__func__`` is returned. An ``AttributeError`` propagates if the
    attribute does not exist or has no ``__func__``.
    """
    descriptor = inspect.getattr_static(cls, fn)
    return descriptor.__func__
807
def get_torchscript_modifier(fn):
810
if hasattr(fn, "__func__"):
812
return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT)
815
def copy_torchscript_modifier(orig, new) -> None:
816
attr = get_torchscript_modifier(orig)
819
new._torchscript_modifier = attr
822
# overloading registration
823
# overloads get registered in this file, and compiled in torch/jit/__init__.py
824
# so that they can be imported in nn/functional.py without an import cycle
826
# qualified_name => list[overload_functions]
827
_overloaded_fns: Dict[str, List[Callable]] = {} # noqa: T484
830
_OVERLOAD_EXAMPLE = """
831
Example usage of overload function:
833
def my_function(x: type0) -> type0: # decl 1
837
def my_function(x: type1) -> type1: # decl 2
840
def my_function(x): # implementation
841
if isinstance(x, type0):
843
elif isinstance(x, type1):
848
def get_overload_no_implementation_error_message(kind, obj):
849
sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
851
f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
852
f"sure a definition is provided and defined after all overload declarations.\n"
853
f'File "{filename}", line {file_lineno}:\n'
854
+ "".join(sourcelines)
860
def _check_overload_body(func):
862
parsed_def = parse_def(func)
864
# Parsing the function definition can raise an OSError if source is unavailable.
865
# Since this is just an initial check, just raise a warning if this is the case.
867
f"Unable to retrieve source for @torch.jit._overload function: {func}."
871
body = parsed_def.ast.body[0].body
874
return isinstance(x, ast.Pass)
877
return isinstance(x, ast.Expr) and isinstance(x.value, ast.Ellipsis)
879
if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
881
"Only `pass` statement or `...` can be the body of overload declaration:\n"
883
msg += "\n".join(parsed_def.source.split("\n")[:3])
884
msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
885
raise RuntimeError(msg)
889
_check_overload_body(func)
890
qual_name = _qualified_name(func)
891
global _overloaded_fns
892
fn_overload_list = _overloaded_fns.get(qual_name)
893
if fn_overload_list is None:
894
fn_overload_list = []
895
_overloaded_fns[qual_name] = fn_overload_list
896
fn_overload_list.append(func)
900
def _get_fn_overloads(qual_name):
    """Fetch the overload list registered under ``qual_name``.

    Returns the value stored in the module-level ``_overloaded_fns`` table,
    or ``None`` when no entry exists for ``qual_name``.
    """
    registered = _overloaded_fns.get(qual_name)
    return registered
904
def _clear_fn_overloads(qual_name) -> None:
    """Drop the entry for ``qual_name`` from the ``_overloaded_fns`` table.

    Raises ``KeyError`` if nothing is registered under ``qual_name``.
    """
    # ``pop`` without a default raises KeyError on a missing key, like ``del``.
    _overloaded_fns.pop(qual_name)
908
def get_class_name_lineno(method) -> Tuple[str, int]:
909
current_frame = inspect.currentframe()
911
# one for the get_class_name call, one for _overload_method call
914
current_frame is not None
915
) # assert current frame is not an Optional[FrameType]
916
current_frame = current_frame.f_back
918
assert current_frame is not None # same here
919
class_name = current_frame.f_code.co_name
920
line_no = current_frame.f_code.co_firstlineno
921
return class_name, line_no
924
# At the point the decorator is applied to class methods the method
925
# has no reference to its owning class. _qualified_name would not include
926
# the class it is defined in, so any methods with the same name in the same file
927
# would have the same _qualified_name, even if they were defined in different
928
# classes. This problem only exists in python 2.
929
# We get around this problem by looking at the stack frame and identifying
930
# the class name, and throwing an error whenever overloads are used
931
# when modules of the same name are in the same file
933
# qualified_name => class name => list[overload_functions]
934
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {} # noqa: T484
937
# (qualified_name, class name) => class_fileno
938
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
941
def _overload_method(func):
942
_check_overload_body(func)
943
qual_name = _qualified_name(func)
944
global _overloaded_methods
945
class_name_map = _overloaded_methods.get(qual_name, None)
946
if class_name_map is None:
948
_overloaded_methods[qual_name] = class_name_map
950
class_name, line_no = get_class_name_lineno(func)
951
method_overloads = class_name_map.get(class_name, None)
952
if method_overloads is None:
953
method_overloads = []
954
class_name_map[class_name] = method_overloads
955
_overloaded_method_class_fileno[(qual_name, class_name)] = line_no
957
existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)]
958
if existing_lineno != line_no:
960
"Cannot currently overload the same method name in two different"
961
" classes with the same name in the same module"
964
method_overloads.append(func)
968
def _get_overloaded_methods(method, mod_class):
969
# TODO: __name__ not set for submodules in recursive script
970
if not hasattr(method, "__name__"):
972
qual_name = _qualified_name(method)
973
class_name_map = _overloaded_methods.get(qual_name, None)
974
if class_name_map is None:
976
overloads = class_name_map.get(mod_class.__name__, None)
977
if overloads is None:
980
method_line_no = get_source_lines_and_file(method)[1]
981
mod_class_fileno = get_source_lines_and_file(mod_class)[1]
982
mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0])
983
if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
985
"Overloads are not useable when a module is redeclared within the same file: "
991
def is_tuple(ann) -> bool:
993
raise_error_container_parameter_missing("Tuple")
995
# For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
996
if not hasattr(ann, "__module__"):
999
ann_origin = get_origin(ann)
1000
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
1002
return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
1005
def is_list(ann) -> bool:
1007
raise_error_container_parameter_missing("List")
1009
if not hasattr(ann, "__module__"):
1012
ann_origin = get_origin(ann)
1013
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
1015
return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
1018
def is_dict(ann) -> bool:
1020
raise_error_container_parameter_missing("Dict")
1022
if not hasattr(ann, "__module__"):
1025
ann_origin = get_origin(ann)
1026
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
1028
return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
1033
raise_error_container_parameter_missing("Union")
1035
return isinstance(ann, BuiltinUnionType) or (
1036
hasattr(ann, "__module__")
1037
and ann.__module__ == "typing"
1038
and (get_origin(ann) is Union)
1042
def is_optional(ann):
1044
raise_error_container_parameter_missing("Optional")
1046
def is_optional_as_optional(ann):
1048
hasattr(ann, "__module__")
1049
and ann.__module__ == "typing"
1050
and (get_origin(ann) is Optional)
1053
def is_union_as_optional(ann):
1054
ann_args = get_args(ann)
1055
return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
1057
return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
1060
def is_future(ann) -> bool:
1063
"Attempted to use Future without a "
1064
"contained type. Please add a contained type, e.g. "
1067
return get_origin(ann) is Future
1070
def is_await(ann) -> bool:
1073
return get_origin(ann) is _Await
1076
if torch.distributed.rpc.is_available():
1077
from torch._C._distributed_rpc import PyRRef
1078
from torch.distributed.rpc import RRef
1080
def is_rref(ann) -> bool:
1083
"Attempted to use RRef without a "
1084
"contained type. Please add a contained type, e.g. "
1087
return get_origin(ann) is RRef
1089
def is_rref_instance(obj) -> bool:
1090
return isinstance(obj, PyRRef)
1094
def is_rref_instance(obj) -> bool:
1095
# If the RPC module doesn't exist then RRefs don't exist either.
1099
def is_final(ann) -> bool:
1101
hasattr(ann, "__module__")
1102
and ann.__module__ in {"typing", "typing_extensions"}
1103
and (get_origin(ann) is Final or isinstance(ann, type(Final)))
1107
# allows BroadcastingList instance to be subscriptable
1108
class BroadcastingListCls:
1109
def __getitem__(self, types):
1113
# mypy doesn't support parameters on types, so we have to explicitly type each
1115
BroadcastingList1 = BroadcastingListCls()
1116
for i in range(2, 7):
1117
globals()[f"BroadcastingList{i}"] = BroadcastingList1
1120
def is_scripting() -> bool:
1122
Function that returns True when in compilation and False otherwise. This
1123
is useful especially with the @unused decorator to leave code in your
1124
model that is not yet TorchScript compatible.
1130
def unsupported_linear_op(x):
1134
if torch.jit.is_scripting():
1135
return torch.linear(x)
1137
return unsupported_linear_op(x)
1142
# Retrieves a fully-qualified name (module hierarchy + classname) for a given obj.
1143
def _qualified_name(obj, mangle_name=True) -> str:
1144
# This special case allows us to override the qualified name on a type.
1145
# It's currently used in conjunction with tracing, where we create a
1146
# fake module to filter only supported attributes. However, since this
1147
# new type is defined as a local class, we need a mechanism to override
1148
# its qualname so it appears correctly in the TorchScript system. This,
1149
# we set '_jit_override_qualname' with the original traced module's
1150
# qualified name, which is picked up here
1151
if hasattr(obj, "_jit_override_qualname"):
1152
return obj._jit_override_qualname
1153
# short-circuit in cases where the object already has a known qualified name
1154
if isinstance(obj, torch._C.ScriptFunction):
1155
return obj.qualified_name
1157
if getattr(obj, "__name__", None):
1159
# Enum classes do not have `__name__` attr, instead they have `name`.
1160
elif isinstance(obj, enum.Enum):
1163
raise RuntimeError("Could not get name of python class object")
1165
if name == "<lambda>":
1166
name = "_lambda" # make name a valid identifier
1168
module_name = obj.__module__
1170
# If the module is actually a torchbind module, then we should short circuit
1171
if module_name == "torch._classes":
1172
return obj.qualified_name
1174
# The Python docs are very clear that `__module__` can be None, but I can't
1175
# figure out when it actually would be.
1176
if module_name is None:
1178
f"Could not get qualified name for class '{name}': "
1179
"__module__ can't be None."
1182
# if getattr(sys.modules[module_name], name) is not obj:
1183
# raise RuntimeError(f"Could not get qualified name for class '{name}': "
1184
# f"the attr {name} on module {module_name} is not the class")
1186
# torch.package and TorchScript have separate mangling schemes to avoid
1187
# name collisions from multiple packages. To avoid them interfering with
1188
# each other, normalize the package manging here.
1189
if package_mangling.is_mangled(module_name):
1190
module_name = module_name.replace("<", "_")
1191
module_name = module_name.replace(">", "_")
1193
# The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h
1194
# does not need mangle the python class name.
1196
# __main__ is a builtin module, so rewrite it to "__torch__".
1197
if module_name == "__main__":
1198
module_name = "__torch__"
1200
# Everything else gets a "__torch__" prefix to avoid name collisions
1201
# with the names of user values.
1202
module_name = "__torch__." + module_name
1206
f"Could not get qualified name for class '{name}': "
1207
f"'{name}' is not a valid identifier"
1210
return module_name + "." + name
1213
def _try_get_dispatched_fn(fn):
1214
if not callable(fn):
1216
return boolean_dispatched.get(fn)
1219
def _get_named_tuple_properties(
1220
obj, loc: Optional[torch._C._jit_tree_views.SourceRange] = None, rcb=None
1225
assert issubclass(obj, tuple) and hasattr(obj, "_fields")
1226
if hasattr(obj, "_field_defaults"):
1228
obj._field_defaults[field]
1229
for field in obj._fields
1230
if field in obj._field_defaults
1234
# In 3.10 recommended way to get annotations is to call `inspect.get_annotations` function
1235
# Also, annotations from base class are not inherited so they need to be queried explicitly
1236
if sys.version_info[:2] < (3, 10):
1237
obj_annotations = getattr(obj, "__annotations__", {})
1239
obj_annotations = inspect.get_annotations(obj)
1240
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
1241
obj_annotations = inspect.get_annotations(obj.__base__)
1244
for field in obj._fields:
1245
if field in obj_annotations:
1246
field_type = obj_annotations[field]
1247
# [Note: ForwardRef annotations in NamedTuple attributes]
1248
# NamedTuple types are slightly different from normal types.
1250
# Normally, annotations are evaluted like this (during jit.script):
1251
# 1. Load strings of python code into c++ and parse.
1252
# 2. Get annotations as strings
1253
# 3. Use the PythonResolver's resolution callback (rcb) to convert
1254
# the string into a python object
1255
# 4. We call into annotations.py:ann_to_type to convert python obj
1256
# from step 3 into a type that torchscript understands.
1258
# NamedTuples are more complicated, because it has sub-types.
1259
# Normally, once we have the NamedTuple type object from #3,
1260
# we can just look at the annotation literal values and use
1261
# ann_to_type directly on them.
1263
# But sometimes, users will annotate with string literals, e.g.
1265
# This also happens with PEP563 (from __forward__ import annotations)
1267
# These annotations appear in the annotation dict as ForwardRef('int').
1269
# Then, we need to convert the string into a python object. This
1270
# requires having local context for custom objects or imported types.
1271
# rcb() is what gives us this. So, we plumb rcb through the stack so
1272
# it can be used in this context for the if block below.
1275
# - Why do we need this special handling for NamedTuple but string
1276
# annotations work fine for normal types? Normally, we parse the
1277
# string directly and then call rcb() directly from C++.
1278
# - Why not use ForwardRef._evaluate? For that, we need globals()
1279
# and locals() for the local context where the NamedTuple was defined.
1280
# rcb is what lets us look up into these. So, basically rcb does the
1282
if isinstance(field_type, ForwardRef) and rcb is not None:
1283
rcb_type = rcb(field_type.__forward_arg__)
1284
# rcb returns None if it can't find anything.
1285
if rcb_type is None:
1287
f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
1288
f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
1289
f" Issue occurred at {loc.highlight()}"
1291
field_type = rcb_type
1292
the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
1293
annotations.append(the_type)
1295
annotations.append(torch._C.TensorType.getInferred())
1296
return type(obj).__name__, obj._fields, annotations, defaults
1299
def _create_named_tuple(
1300
t, unqual_name: str, field_names: List[str], defaults: Tuple[Any, ...]
1302
TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc]
1303
return TupleType(*t)
1306
@contextlib.contextmanager
1307
def _disable_emit_hooks():
1308
hooks = torch._C._jit_get_emit_hooks()
1309
torch._C._jit_set_emit_hooks(None, None)
1313
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
1316
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811
1317
def __enter__(self) -> None:
1318
self.hooks = torch._C._jit_get_emit_hooks()
1319
torch._C._jit_set_emit_hooks(None, None)
1321
def __exit__(self, *args) -> None:
1322
torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
1325
def _is_exception(obj) -> bool:
1326
if not inspect.isclass(obj):
1328
return issubclass(obj, Exception)
1331
def raise_error_container_parameter_missing(target_type) -> None:
    """Raise the "missing contained type" error for a container annotation.

    Args:
        target_type: Display name of the container, e.g. ``"List"`` or ``"Dict"``.

    Raises:
        RuntimeError: always. ``"Dict"`` gets a message with a two-parameter
            example; every other name gets a one-parameter example.
    """
    if target_type == "Dict":
        raise RuntimeError(
            "Attempted to use Dict without "
            "contained types. Please add contained type, e.g. "
            "Dict[int, int]"
        )
    raise RuntimeError(
        f"Attempted to use {target_type} without a "
        "contained type. Please add a contained type, e.g. "
        f"{target_type}[int]"
    )
1345
def check_args_exist(target_type) -> None:
    """Reject bare container annotations that lack contained types.

    Raises (through ``raise_error_container_parameter_missing``) when
    ``target_type`` is an unparameterized ``List``/``list``,
    ``Tuple``/``tuple``, ``Dict``/``dict``, ``Optional``, or ``None``.
    Any other value passes silently.
    """
    if target_type is List or target_type is list:
        raise_error_container_parameter_missing("List")
    elif target_type is Tuple or target_type is tuple:
        raise_error_container_parameter_missing("Tuple")
    elif target_type is Dict or target_type is dict:
        raise_error_container_parameter_missing("Dict")
    elif target_type is None or target_type is Optional:
        # None is reported as a bare Optional.
        raise_error_container_parameter_missing("Optional")
1356
def check_empty_containers(obj) -> None:
    """Warn when ``obj`` compares equal to an empty list, dict, or tuple."""
    if obj == [] or obj == {} or obj == ():
        warnings.warn(
            "The inner type of a container is lost when "
            "calling torch.jit.isinstance in eager mode. For "
            "example, List[int] would become list and "
            "therefore falsely return True for List[float] or"
            " List[str]."
        )
1367
# supports List/Dict/Tuple and Optional types
# TODO support future
def container_checker(obj, target_type) -> bool:
    """Recursively check ``obj`` against a parameterized container annotation.

    Handles ``List``, ``Dict``, ``Tuple`` and ``Union``/``Optional`` (including
    the builtin ``X | Y`` union type). Returns False for annotations with no
    origin or an unsupported origin.

    Raises:
        RuntimeError: through ``check_args_exist`` when ``target_type`` is a
            bare, unparameterized container.
    """
    origin_type = get_origin(target_type)
    check_args_exist(target_type)
    if origin_type is None:
        return False
    elif origin_type is list or origin_type is List:
        check_empty_containers(obj)
        if not isinstance(obj, list):
            return False
        arg_type = get_args(target_type)[0]
        arg_origin = get_origin(arg_type)
        for el in obj:
            # check if nested container, ex: List[List[str]]
            if arg_origin:  # processes nested container, ex: List[List[str]]
                if not container_checker(el, arg_type):
                    return False
            elif not isinstance(el, arg_type):
                return False
        return True
    elif origin_type is Dict or origin_type is dict:
        check_empty_containers(obj)
        if not isinstance(obj, dict):
            return False
        key_type = get_args(target_type)[0]
        val_type = get_args(target_type)[1]
        for key, val in obj.items():
            # check if keys are of right type
            if not isinstance(key, key_type):
                return False
            val_origin = get_origin(val_type)
            if val_origin:
                if not container_checker(val, val_type):
                    return False
            elif not isinstance(val, val_type):
                return False
        return True
    elif origin_type is Tuple or origin_type is tuple:
        check_empty_containers(obj)
        if not isinstance(obj, tuple):
            return False
        arg_types = get_args(target_type)
        if len(obj) != len(arg_types):
            return False
        for el, el_type in zip(obj, arg_types):
            el_origin = get_origin(el_type)
            if el_origin:
                if not container_checker(el, el_type):
                    return False
            elif not isinstance(el, el_type):
                return False
        return True
    elif origin_type is Union or issubclass(
        origin_type, BuiltinUnionType
    ):  # also handles Optional
        if obj is None:  # check before recursion because None is always fine
            return True
        inner_types = get_args(target_type)
        for t in inner_types:
            t_origin = get_origin(t)
            if t_origin:
                # A failed match on one parameterized member must not end the
                # search: later members of the union may still accept obj
                # (e.g. Union[List[int], str] with a str).
                if container_checker(obj, t):
                    return True
            elif isinstance(obj, t):
                return True
    return False
1435
def _isinstance(obj, target_type) -> bool:
    """``isinstance`` variant that understands parameterized container types.

    Args:
        obj: Value to test.
        target_type: A type, a parameterized typing annotation, or a tuple
            of those (matching any element is enough).

    Raises:
        RuntimeError: if ``target_type`` is a non-tuple container instance,
            or (through ``check_args_exist``) a bare unparameterized container.
    """
    if isinstance(target_type, collections.abc.Container):
        if not isinstance(target_type, tuple):
            raise RuntimeError(
                "The second argument to "
                "`torch.jit.isinstance` must be a type "
                "or a tuple of types"
            )
        for t_type in target_type:
            if _isinstance(obj, t_type):
                return True
        return False

    origin_type = get_origin(target_type)
    if origin_type:
        return container_checker(obj, target_type)

    # Check to handle non-typed optional origin returns as none instead
    #    of as optional in 3.7-3.8
    check_args_exist(target_type)

    # handle non-containers
    return isinstance(obj, target_type)
1460
class _TensorExtractor(pickle.Pickler):
1461
def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
1462
super().__init__(*args, **kwargs)
1463
self.tensors = tensors
1465
def persistent_id(self, obj):
1466
if isinstance(obj, torch.Tensor):
1467
self.tensors.append(obj)
1469
# Since we just want to extract tensors, we don't mind if an object is
1470
# unpicklable if it doesn't contain tensors, as we can just ignore/skip
1471
# it. To play it safe, we only do so for common objects that we're sure
1472
# don't contain tensors. Feel free to add new types here. Note also that
1473
# even if a type isn't listed here this won't block users, since thet
1474
# can just add a __getstate__ or __reduce__ method to their class.
1475
if isinstance(obj, LockType):
1477
# Futures and RRefs don't technically contain a value, they just offer
1478
# the means to access a value.
1479
if isinstance(obj, CFuture) or is_rref_instance(obj):
1481
if isinstance(obj, CAwait):
1483
if isinstance(obj, torch.cuda.Event):
1485
if isinstance(obj, threading.Thread):
1490
def _extract_tensors(obj):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/jit/python/python_ivalue.h``.

    It extracts the tensors contained in the given object, through pickling.
    """
    tensors: List[torch.Tensor] = []
    # The pickled bytes are discarded; only the tensors list filled by the
    # extractor is of interest.
    extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
    extractor.dump(obj)
    return tensors
1503
# In Python-3.11+ typed enums (i.e. IntEnum for example) retain number of base class methods in subclass
# that were previously dropped. To preserve the behavior, explicitly drop them there

# NOTE(review): `sys.version_info > (3, 10)` is also true on 3.10.x, because
# the five-field version tuple compares greater than the two-element tuple.
# Confirm whether `>= (3, 11)` was intended.
if sys.version_info > (3, 10):
    _drop(enum.Enum.__new__)
    _drop(enum.Enum.__format__)
    _drop(enum.Enum.__repr__)
    _drop(enum.Enum.__str__)