3
The weak_script annotation needs to be here instead of inside torch/jit/ so it
4
can be used in other places in torch/ (namely torch.nn) without running into
5
circular dependency problems.
43
import torch.distributed.rpc
44
import torch.package._mangling as package_mangling
45
from torch._awaits import _Await
46
from torch._C import _Await as CAwait, Future as CFuture
47
from torch._sources import fake_range, get_source_lines_and_file, parse_def
48
from torch.futures import Future
51
IS_PY39_PLUS: Final[bool] = sys.version_info >= (3, 9)
52
IS_PY310_PLUS: Final[bool] = sys.version_info >= (3, 10)
54
BuiltinUnionType: Union[Type, Tuple[Type, ...]]
55
if sys.version_info >= (3, 10):
58
BuiltinUnionType = types.UnionType
66
LockType = _thread.LockType
70
LockType = _dummy_thread.LockType
74
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = (
75
weakref.WeakKeyDictionary()
79
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
82
def is_final(ann) -> bool:
84
hasattr(ann, "__module__")
85
and ann.__module__ in {"typing", "typing_extensions"}
86
and (get_origin(ann) is Final or isinstance(ann, type(Final)))
91
class BroadcastingListCls:
92
def __getitem__(self, types):
98
BroadcastingList1 = BroadcastingListCls()
100
globals()[f"BroadcastingList{i}"] = BroadcastingList1
103
def is_scripting() -> bool:
105
Function that returns True when in compilation and False otherwise. This
106
is useful especially with the @unused decorator to leave code in your
107
model that is not yet TorchScript compatible.
113
def unsupported_linear_op(x):
117
if torch.jit.is_scripting():
118
return torch.linear(x)
120
return unsupported_linear_op(x)
126
def _qualified_name(obj, mangle_name=True) -> str:
134
if hasattr(obj, "_jit_override_qualname"):
135
return obj._jit_override_qualname
137
if isinstance(obj, torch._C.ScriptFunction):
138
return obj.qualified_name
140
if getattr(obj, "__name__", None):
143
elif isinstance(obj, enum.Enum):
146
raise RuntimeError("Could not get name of python class object")
148
if name == "<lambda>":
151
module_name = obj.__module__
154
if module_name == "torch._classes":
155
return obj.qualified_name
159
if module_name is None:
161
f"Could not get qualified name for class '{name}': "
162
"__module__ can't be None."
172
if package_mangling.is_mangled(module_name):
173
module_name = module_name.replace("<", "_")
174
module_name = module_name.replace(">", "_")
180
if module_name == "__main__":
181
module_name = "__torch__"
185
module_name = "__torch__." + module_name
189
f"Could not get qualified name for class '{name}': "
190
f"'{name}' is not a valid identifier"
193
return module_name + "." + name
200
def cache(self, fn, source):
    """Record `source` as the source text associated with `fn`.

    A later call for the same `fn` overwrites the earlier entry.
    """
    store = self.content
    store[fn] = source
203
def get_source(self, fn):
    """Return the source text cached for `fn`, or None if nothing was cached."""
    cached = self.content.get(fn)
    return cached
207
loader = SourceLoader()
210
def createResolutionCallbackFromEnv(lookup_base):
212
Creates a resolution callback that will look up qualified names in an
213
environment, starting with `lookup_base` for the base of any qualified
214
names, then proceeding down the lookup chain with the resolved object.
216
You should not use this directly, it should only be used from the other
217
createResolutionCallbackFrom* functions.
220
def lookupInModule(qualified_name, module):
221
if "." in qualified_name:
222
base, remaining_pieces = qualified_name.split(".", maxsplit=1)
223
module_value = getattr(module, base)
224
return lookupInModule(remaining_pieces, module_value)
226
return getattr(module, qualified_name)
228
def parseNestedExpr(expr, module) -> Tuple[Any, int]:
230
while i < len(expr) and expr[i] not in (",", "[", "]"):
238
base = lookupInModule(expr[:i].strip(), module)
239
assert base is not None, f"Unresolvable type {expr[:i]}"
240
if i == len(expr) or expr[i] != "[":
243
assert expr[i] == "["
245
while expr[i] != "]":
248
part, part_len = parseNestedExpr(expr[i:], module)
252
return base[tuple(parts)], i + 1
254
return base[parts[0]], i + 1
256
def parseExpr(expr, module):
258
value, len_parsed = parseNestedExpr(expr, module)
259
assert len_parsed == len(
261
), "whole expression was not parsed, falling back to c++ parser"
265
The python resolver fails in several cases in known unit tests, and is intended
266
to fall back gracefully to the c++ resolver in general. For example, python 2 style
267
annotations which are frequent in our unit tests often fail with types e.g. int not
268
resolvable from the calling frame.
272
return lambda expr: parseExpr(expr, lookup_base)
275
def createResolutionCallbackFromFrame(frames_up: int = 0):
277
Creates a function which, given a string variable name,
278
returns the value of the variable in the scope of the caller of
279
the function which called createResolutionCallbackFromFrame (by default).
281
This is used to enable access in-scope Python variables inside
282
TorchScript fragments.
284
frames_up is number of additional frames to go up on the stack.
285
The default value is 0, which correspond to the frame of the caller
286
of createResolutionCallbackFromFrame. Also for example, if frames_up is set
287
to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
290
For example, the following program prints 2::
293
cb = createResolutionCallbackFromFrame(1)
304
frame = inspect.currentframe()
306
while i < frames_up + 1:
307
assert frame is not None
311
assert frame is not None
312
f_locals = frame.f_locals
313
f_globals = frame.f_globals
316
def __getattr__(self, key):
319
elif key in f_globals:
320
return f_globals[key]
321
elif key in dir(builtins):
322
return getattr(builtins, key)
324
return createResolutionCallbackFromEnv(env())
329
Get a dictionary of closed over variables from a function
332
captures.update(fn.__globals__)
334
for index, captured_name in enumerate(fn.__code__.co_freevars):
335
captures[captured_name] = fn.__closure__[index].cell_contents
385
def createResolutionCallbackFromClosure(fn):
387
Create a resolutionCallback by introspecting the function instead of
388
looking up the stack for the enclosing scope
390
closure = get_closure(fn)
392
class closure_lookup:
395
def __getattr__(self, key):
398
elif hasattr(typing, key):
399
return getattr(typing, key)
400
elif hasattr(builtins, key):
401
return getattr(builtins, key)
404
return createResolutionCallbackFromEnv(closure_lookup())
407
def can_compile_class(cls) -> bool:
410
if is_ignored_fn(cls):
414
ignored_builtin_classes = (torch.nn.Module, tuple, list, Exception)
415
if issubclass(cls, ignored_builtin_classes):
422
if inspect.isroutine(getattr(cls, name, None))
424
has_code = [hasattr(fn, "__code__") for fn in fns]
428
def get_callable_argument_names(fn) -> List[str]:
430
Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.
431
Returns an empty list when other types of arguments are present.
433
This is used by `torch.jit.trace` to assign meaningful argument names to
434
traced functions and modules.
439
Argument names: List[str]
443
callable_signature = inspect.signature(fn)
448
for name, param in callable_signature.parameters.items():
451
if not param.kind == param.POSITIONAL_OR_KEYWORD:
454
argument_names.append(name)
456
return argument_names
459
def get_annotation_str(annotation):
461
Convert an AST node containing a type annotation to the string present in the source
462
that represents the same annotation.
464
if isinstance(annotation, ast.Name):
466
elif isinstance(annotation, ast.Attribute):
467
return ".".join([get_annotation_str(annotation.value), annotation.attr])
468
elif isinstance(annotation, ast.Subscript):
470
subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value
471
return f"{get_annotation_str(annotation.value)}[{get_annotation_str(subscript_slice)}]"
472
elif isinstance(annotation, ast.Tuple):
473
return ",".join([get_annotation_str(elt) for elt in annotation.elts])
474
elif isinstance(annotation, ast.Constant):
475
return f"{annotation.value}"
481
def get_type_hint_captures(fn):
483
Get a dictionary containing type resolution mappings necessary to resolve types
484
for the literal annotations on 'fn'. These are not considered to be closed-over by fn
485
and must be obtained separately (e.g. using this function).
490
A Dict[str, Any] containing a mapping from the literal annotations used on
491
fn to the Python objects they refer to.
497
src = loader.get_source(fn)
500
src = inspect.getsource(fn)
503
f"Failed to get source for {fn} using inspect.getsource"
511
signature = inspect.signature(fn)
513
name: parameter.annotation
514
for name, parameter in signature.parameters.items()
515
if parameter.annotation is not inspect.Parameter.empty
516
and not isinstance(parameter.annotation, str)
523
a = ast.parse(textwrap.dedent(src))
524
if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
525
raise RuntimeError(f"Expected {fn} to be a function")
531
annotation_to_type = {}
533
for arg in f.args.args:
535
arg_annotation_str = (
536
get_annotation_str(arg.annotation) if arg.annotation else None
542
if arg_annotation_str is None:
549
if arg_name in name_to_type:
550
annotation_to_type[arg_annotation_str] = name_to_type[arg_name]
555
literal_return_annotation = get_annotation_str(f.returns)
556
valid_literal_annotation = literal_return_annotation is not None
557
return_annotation = signature.return_annotation
558
valid_return_annotation_type = (
559
return_annotation is not inspect.Parameter.empty
560
and not isinstance(return_annotation, str)
562
if valid_literal_annotation and valid_return_annotation_type:
563
annotation_to_type[literal_return_annotation] = return_annotation
565
return annotation_to_type
568
def createResolutionCallbackForClassMethods(cls):
570
This looks at all the methods defined in a class and pulls their closed-over
571
variables into a dictionary and uses that to resolve variables.
577
for name in cls.__dict__
578
if inspect.isroutine(getattr(cls, name))
583
fns = [fn for fn in fns if not inspect.isbuiltin(fn) and hasattr(fn, "__globals__")]
587
captures.update(get_closure(fn))
588
captures.update(get_type_hint_captures(fn))
590
def lookup_in_class(key):
594
return getattr(builtins, key, None)
596
return lookup_in_class
609
Dispatches to either of 2 script functions based on a boolean argument.
610
In TorchScript, the boolean argument must be constant so that the correct
611
function to use can be determined at compile time.
614
def fn(*args, **kwargs):
615
dispatch_flag = default
616
if arg_name in kwargs:
617
dispatch_flag = kwargs[arg_name]
618
elif arg_index < len(args):
619
dispatch_flag = args[arg_index]
622
return if_true(*args, **kwargs)
624
return if_false(*args, **kwargs)
626
if if_true.__doc__ is None and if_false.__doc__ is not None:
627
doc = if_false.__doc__
628
if_true.__doc__ = doc
629
elif if_false.__doc__ is None and if_true.__doc__ is not None:
630
doc = if_true.__doc__
631
if_false.__doc__ = doc
632
elif if_false.__doc__ is None and if_true.__doc__ is None:
636
raise RuntimeError("only one function can have a docstring")
639
if module_name is not None:
640
fn.__module__ = module_name
641
if func_name is not None:
642
fn.__name__ = func_name
644
boolean_dispatched[fn] = {
646
"if_false": if_false,
649
"arg_name": arg_name,
654
class FunctionModifiers:
656
Used to denote the behavior of a function in TorchScript. See export() and
657
ignore() for details.
660
UNUSED = "unused (ignored and replaced with raising of an exception)"
661
IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
662
EXPORT = "export (compile this function even if nothing calls it)"
663
DEFAULT = "default (compile if called from a exported function / forward)"
664
COPY_TO_SCRIPT_WRAPPER = (
665
"if this method is not scripted, copy the python method onto the scripted model"
667
_DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
672
This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
673
:class:`ScriptModule` and should be compiled.
675
``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
676
Functions and methods called from ``forward`` are compiled as they are seen
677
by the compiler, so they do not need this decorator either.
679
Example (using ``@torch.jit.export`` on a method):
684
import torch.nn as nn
686
class MyModule(nn.Module):
687
def implicitly_compiled_method(self, x):
690
# `forward` is implicitly decorated with `@torch.jit.export`,
691
# so adding it here would have no effect
692
def forward(self, x):
696
def another_forward(self, x):
697
# When the compiler sees this call, it will compile
698
# `implicitly_compiled_method`
699
return self.implicitly_compiled_method(x)
701
def unused_method(self, x):
704
# `m` will contain compiled methods:
707
# `implicitly_compiled_method`
708
# `unused_method` will not be compiled since it was not called from
709
# any compiled methods and wasn't decorated with `@torch.jit.export`
710
m = torch.jit.script(MyModule())
712
fn._torchscript_modifier = FunctionModifiers.EXPORT
718
This decorator indicates to the compiler that a function or method should
719
be ignored and replaced with the raising of an exception. This allows you
720
to leave code in your model that is not yet TorchScript compatible and still
723
Example (using ``@torch.jit.unused`` on a method)::
726
import torch.nn as nn
729
class MyModule(nn.Module):
730
def __init__(self, use_memory_efficient):
732
self.use_memory_efficient = use_memory_efficient
735
def memory_efficient(self, x):
741
def forward(self, x):
742
# Use not-yet-scriptable memory efficient mode
743
if self.use_memory_efficient:
744
return self.memory_efficient(x)
749
m = torch.jit.script(MyModule(use_memory_efficient=False))
752
m = torch.jit.script(MyModule(use_memory_efficient=True))
756
if isinstance(fn, property):
759
prop.fget, "_torchscript_modifier", FunctionModifiers.UNUSED
764
prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
769
fn._torchscript_modifier = FunctionModifiers.UNUSED
774
class _IgnoreContextManager(contextlib.AbstractContextManager):
775
def __init__(self, **kwargs):
778
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
782
def ignore(drop=False, **kwargs):
784
This decorator indicates to the compiler that a function or method should
785
be ignored and left as a Python function. This allows you to leave code in
786
your model that is not yet TorchScript compatible. If called from TorchScript,
787
ignored functions will dispatch the call to the Python interpreter. Models with ignored
788
functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.
790
Example (using ``@torch.jit.ignore`` on a method)::
793
import torch.nn as nn
796
class MyModule(nn.Module):
798
def debugger(self, x):
803
def forward(self, x):
805
# The compiler would normally try to compile `debugger`,
806
# but since it is `@ignore`d, it will be left as a call
812
m = torch.jit.script(MyModule())
814
# Error! The call `debugger` cannot be saved since it calls into Python
817
Example (using ``@torch.jit.ignore(drop=True)`` on a method):
822
import torch.nn as nn
824
class MyModule(nn.Module):
825
@torch.jit.ignore(drop=True)
826
def training_method(self, x):
830
def forward(self, x):
832
self.training_method(x)
835
m = torch.jit.script(MyModule())
837
# This is OK since `training_method` is not saved, the call is replaced
852
fn._torchscript_modifier = FunctionModifiers.IGNORE
855
if not isinstance(drop, bool):
857
"Argument to @torch.jit.ignore must be a bool or "
858
f"a function but got {drop}"
862
drop_on_export = kwargs.pop("drop_on_export", None)
865
"ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
866
"call on compilation. Use torch.jit.unused now. {}",
867
category=FutureWarning,
870
drop = drop_on_export
873
"ignore(True) has been deprecated. TorchScript will now drop the function "
874
"call on compilation. Use torch.jit.unused now. {}",
875
category=FutureWarning,
880
fn._torchscript_modifier = FunctionModifiers.UNUSED
882
fn._torchscript_modifier = FunctionModifiers.IGNORE
889
fn._torchscript_modifier = FunctionModifiers._DROP
893
def _copy_to_script_wrapper(fn):
894
fn._torchscript_modifier = FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
898
def module_has_exports(mod):
899
for name in dir(mod):
900
if hasattr(mod, name):
901
item = getattr(mod, name)
903
if get_torchscript_modifier(item) is FunctionModifiers.EXPORT:
911
def should_drop(fn) -> bool:
912
attr = get_torchscript_modifier(fn)
915
return attr is FunctionModifiers.UNUSED or attr is FunctionModifiers._DROP
918
def is_ignored_fn(fn) -> bool:
919
mod = get_torchscript_modifier(fn)
921
mod is FunctionModifiers.UNUSED
922
or mod is FunctionModifiers.IGNORE
923
or mod is FunctionModifiers._DROP
927
def _is_drop_fn(fn) -> bool:
    """Return True when `fn` is tagged with ``FunctionModifiers._DROP``."""
    return get_torchscript_modifier(fn) is FunctionModifiers._DROP
932
def is_static_fn(cls, fn) -> bool:
    """Return True if attribute `fn` of `cls` is a ``staticmethod``.

    ``inspect.getattr_static`` fetches the raw descriptor without running the
    descriptor protocol. A missing attribute resolves to None and so yields False.
    """
    raw_attr = inspect.getattr_static(cls, fn, default=None)
    return isinstance(raw_attr, staticmethod)
936
def get_static_fn(cls, fn):
    """Return the plain function wrapped by the ``staticmethod`` named `fn` on `cls`.

    The descriptor is fetched with ``inspect.getattr_static`` so that attribute
    access does not unwrap it. Its ``__func__`` is then returned.

    Raises:
        AttributeError: if `cls` has no attribute `fn`, or the attribute has
            no ``__func__``.
    """
    wrapper = inspect.getattr_static(cls, fn)
    return wrapper.__func__
940
def get_torchscript_modifier(fn):
943
if hasattr(fn, "__func__"):
945
return getattr(fn, "_torchscript_modifier", FunctionModifiers.DEFAULT)
948
def copy_torchscript_modifier(orig, new) -> None:
949
attr = get_torchscript_modifier(orig)
952
new._torchscript_modifier = attr
960
_overloaded_fns: Dict[str, List[Callable]] = {}
963
_OVERLOAD_EXAMPLE = """
964
Example usage of overload function:
966
def my_function(x: type0) -> type0: # decl 1
970
def my_function(x: type1) -> type1: # decl 2
973
def my_function(x): # implementation
974
if isinstance(x, type0):
976
elif isinstance(x, type1):
981
def get_overload_no_implementation_error_message(kind, obj):
982
sourcelines, file_lineno, filename = get_source_lines_and_file(obj)
984
f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
985
f"sure a definition is provided and defined after all overload declarations.\n"
986
f'File "{filename}", line {file_lineno}:\n'
987
+ "".join(sourcelines)
993
def _check_overload_body(func):
995
parsed_def = parse_def(func)
1000
f"Unable to retrieve source for @torch.jit._overload function: {func}."
1004
body = parsed_def.ast.body[0].body
1007
return isinstance(x, ast.Pass)
1011
isinstance(x, ast.Expr)
1012
and isinstance(x.value, ast.Constant)
1013
and x.value.value is Ellipsis
1016
if len(body) != 1 or not (is_pass(body[0]) or is_ellipsis(body[0])):
1018
"Only `pass` statement or `...` can be the body of overload declaration:\n"
1020
msg += "\n".join(parsed_def.source.split("\n")[:3])
1021
msg += " <- Expecting `pass` or `...` here!\n" + _OVERLOAD_EXAMPLE
1022
raise RuntimeError(msg)
1026
_check_overload_body(func)
1027
qual_name = _qualified_name(func)
1028
global _overloaded_fns
1029
fn_overload_list = _overloaded_fns.get(qual_name)
1030
if fn_overload_list is None:
1031
fn_overload_list = []
1032
_overloaded_fns[qual_name] = fn_overload_list
1033
fn_overload_list.append(func)
1037
def _get_fn_overloads(qual_name):
    """Return the overload declarations stored in ``_overloaded_fns`` for `qual_name`.

    Returns None when nothing has been registered under that name.
    """
    try:
        return _overloaded_fns[qual_name]
    except KeyError:
        return None
1041
def _clear_fn_overloads(qual_name) -> None:
    """Remove the overload declarations stored in ``_overloaded_fns`` for `qual_name`.

    Raises:
        KeyError: if nothing is registered under `qual_name`.
    """
    _overloaded_fns.pop(qual_name)
1045
def get_class_name_lineno(method) -> Tuple[str, int]:
1046
current_frame = inspect.currentframe()
1051
current_frame is not None
1053
current_frame = current_frame.f_back
1055
assert current_frame is not None
1056
class_name = current_frame.f_code.co_name
1057
line_no = current_frame.f_code.co_firstlineno
1058
return class_name, line_no
1071
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {}
1075
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
1078
def _overload_method(func):
1079
_check_overload_body(func)
1080
qual_name = _qualified_name(func)
1081
global _overloaded_methods
1082
class_name_map = _overloaded_methods.get(qual_name, None)
1083
if class_name_map is None:
1085
_overloaded_methods[qual_name] = class_name_map
1087
class_name, line_no = get_class_name_lineno(func)
1088
method_overloads = class_name_map.get(class_name, None)
1089
if method_overloads is None:
1090
method_overloads = []
1091
class_name_map[class_name] = method_overloads
1092
_overloaded_method_class_fileno[(qual_name, class_name)] = line_no
1094
existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)]
1095
if existing_lineno != line_no:
1097
"Cannot currently overload the same method name in two different"
1098
" classes with the same name in the same module"
1101
method_overloads.append(func)
1105
def _get_overloaded_methods(method, mod_class):
1107
if not hasattr(method, "__name__"):
1109
qual_name = _qualified_name(method)
1110
class_name_map = _overloaded_methods.get(qual_name, None)
1111
if class_name_map is None:
1113
overloads = class_name_map.get(mod_class.__name__, None)
1114
if overloads is None:
1117
method_line_no = get_source_lines_and_file(method)[1]
1118
mod_class_fileno = get_source_lines_and_file(mod_class)[1]
1119
mod_end_fileno = mod_class_fileno + len(get_source_lines_and_file(mod_class)[0])
1120
if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
1121
raise AssertionError(
1122
"Overloads are not useable when a module is redeclared within the same file: "
1128
def is_tuple(ann) -> bool:
1130
raise_error_container_parameter_missing("Tuple")
1133
if not hasattr(ann, "__module__"):
1136
ann_origin = get_origin(ann)
1137
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is tuple:
1139
return ann.__module__ == "typing" and (ann_origin is Tuple or ann_origin is tuple)
1142
def is_list(ann) -> bool:
1144
raise_error_container_parameter_missing("List")
1146
if not hasattr(ann, "__module__"):
1149
ann_origin = get_origin(ann)
1150
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is list:
1152
return ann.__module__ == "typing" and (ann_origin is List or ann_origin is list)
1155
def is_dict(ann) -> bool:
1157
raise_error_container_parameter_missing("Dict")
1159
if not hasattr(ann, "__module__"):
1162
ann_origin = get_origin(ann)
1163
if IS_PY39_PLUS and ann.__module__ == "builtins" and ann_origin is dict:
1165
return ann.__module__ == "typing" and (ann_origin is Dict or ann_origin is dict)
1170
raise_error_container_parameter_missing("Union")
1172
return isinstance(ann, BuiltinUnionType) or (
1173
hasattr(ann, "__module__")
1174
and ann.__module__ == "typing"
1175
and (get_origin(ann) is Union)
1179
def is_optional(ann):
1181
raise_error_container_parameter_missing("Optional")
1183
def is_optional_as_optional(ann):
1185
hasattr(ann, "__module__")
1186
and ann.__module__ == "typing"
1187
and (get_origin(ann) is Optional)
1190
def is_union_as_optional(ann):
1191
ann_args = get_args(ann)
1192
return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
1194
return is_optional_as_optional(ann) or (is_union(ann) and is_union_as_optional(ann))
1197
def is_future(ann) -> bool:
1200
"Attempted to use Future without a "
1201
"contained type. Please add a contained type, e.g. "
1204
return get_origin(ann) is Future
1207
def is_await(ann) -> bool:
1210
return get_origin(ann) is _Await
1213
if torch.distributed.rpc.is_available():
1214
from torch._C._distributed_rpc import PyRRef
1215
from torch.distributed.rpc import RRef
1217
def is_rref(ann) -> bool:
1220
"Attempted to use RRef without a "
1221
"contained type. Please add a contained type, e.g. "
1224
return get_origin(ann) is RRef
1226
def is_rref_instance(obj) -> bool:
1227
return isinstance(obj, PyRRef)
1231
def is_rref_instance(obj) -> bool:
1236
def _try_get_dispatched_fn(fn):
1237
if not callable(fn):
1239
return boolean_dispatched.get(fn)
1242
def _get_named_tuple_properties(
1244
loc: Optional[torch._C._jit_tree_views.SourceRange] = None,
1250
assert issubclass(obj, tuple) and hasattr(obj, "_fields")
1251
if hasattr(obj, "_field_defaults"):
1253
obj._field_defaults[field]
1254
for field in obj._fields
1255
if field in obj._field_defaults
1261
if sys.version_info[:2] < (3, 10):
1262
obj_annotations = getattr(obj, "__annotations__", {})
1264
obj_annotations = inspect.get_annotations(obj)
1265
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
1266
obj_annotations = inspect.get_annotations(obj.__base__)
1269
for field in obj._fields:
1270
if field in obj_annotations:
1271
field_type = obj_annotations[field]
1307
if isinstance(field_type, ForwardRef) and rcb is not None:
1308
rcb_type = rcb(field_type.__forward_arg__)
1310
if rcb_type is None:
1312
f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
1313
f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
1314
f" Issue occurred at {loc.highlight()}"
1316
field_type = rcb_type
1317
the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
1318
annotations.append(the_type)
1320
annotations.append(torch._C.TensorType.getInferred())
1321
return type(obj).__name__, obj._fields, annotations, defaults
1324
def _create_named_tuple(
1327
field_names: List[str],
1328
defaults: Tuple[Any, ...],
1330
TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults)
1331
return TupleType(*t)
1334
@contextlib.contextmanager
1335
def _disable_emit_hooks():
1336
hooks = torch._C._jit_get_emit_hooks()
1337
torch._C._jit_set_emit_hooks(None, None)
1341
torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
1344
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:
    """Define (but never use) ``__enter__``/``__exit__`` hooks that clear the JIT emit hooks.

    As written this is a plain function, not a class. Calling it only creates
    the two local functions below and returns None.

    NOTE(review): the method-like names and the ``_DecoratorContextManager``
    parameter suggest this was meant to be a context-manager class mirroring
    ``_disable_emit_hooks``. Confirm the intent before relying on it.
    """

    def __enter__(self) -> None:
        # Save the current pair of emit hooks, then clear both.
        self.hooks = torch._C._jit_get_emit_hooks()
        torch._C._jit_set_emit_hooks(None, None)

    def __exit__(self, *args) -> None:
        # Reinstall the two hooks saved by __enter__.
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
1353
def _is_exception(obj) -> bool:
1354
if not inspect.isclass(obj):
1356
return issubclass(obj, Exception)
1359
def raise_error_container_parameter_missing(target_type) -> None:
1360
if target_type == "Dict":
1362
"Attempted to use Dict without "
1363
"contained types. Please add contained type, e.g. "
1367
f"Attempted to use {target_type} without a "
1368
"contained type. Please add a contained type, e.g. "
1369
f"{target_type}[int]"
1373
def check_args_exist(target_type) -> None:
    """Reject bare container annotations that lack their contained type(s).

    The following are reported through
    ``raise_error_container_parameter_missing``:

    - ``List``/``list``
    - ``Tuple``/``tuple``
    - ``Dict``/``dict``
    - ``None``/``Optional``

    Any other `target_type` is accepted silently.
    """
    # Checked in order. Only the first matching entry is reported.
    bare_forms = (
        ("List", (List, list)),
        ("Tuple", (Tuple, tuple)),
        ("Dict", (Dict, dict)),
        ("Optional", (None, Optional)),
    )
    for label, candidates in bare_forms:
        if any(target_type is candidate for candidate in candidates):
            raise_error_container_parameter_missing(label)
            return
1384
def check_empty_containers(obj) -> None:
1385
if obj == [] or obj == {} or obj == ():
1387
"The inner type of a container is lost when "
1388
"calling torch.jit.isinstance in eager mode. For "
1389
"example, List[int] would become list and "
1390
"therefore falsely return True for List[float] or"
1397
def container_checker(obj, target_type) -> bool:
1398
origin_type = get_origin(target_type)
1399
check_args_exist(target_type)
1400
if origin_type is None:
1402
elif origin_type is list or origin_type is List:
1403
check_empty_containers(obj)
1404
if not isinstance(obj, list):
1406
arg_type = get_args(target_type)[0]
1407
arg_origin = get_origin(arg_type)
1411
if not container_checker(el, arg_type):
1413
elif not isinstance(el, arg_type):
1416
elif origin_type is Dict or origin_type is dict:
1417
check_empty_containers(obj)
1418
if not isinstance(obj, dict):
1420
key_type = get_args(target_type)[0]
1421
val_type = get_args(target_type)[1]
1422
for key, val in obj.items():
1424
if not isinstance(key, key_type):
1426
val_origin = get_origin(val_type)
1428
if not container_checker(val, val_type):
1430
elif not isinstance(val, val_type):
1433
elif origin_type is Tuple or origin_type is tuple:
1434
check_empty_containers(obj)
1435
if not isinstance(obj, tuple):
1437
arg_types = get_args(target_type)
1438
if len(obj) != len(arg_types):
1440
for el, el_type in zip(obj, arg_types):
1441
el_origin = get_origin(el_type)
1443
if not container_checker(el, el_type):
1445
elif not isinstance(el, el_type):
1448
elif origin_type is Union or issubclass(
1449
origin_type, BuiltinUnionType
1453
inner_types = get_args(target_type)
1454
for t in inner_types:
1455
t_origin = get_origin(t)
1457
return container_checker(obj, t)
1458
elif isinstance(obj, t):
1463
def _isinstance(obj, target_type) -> bool:
1464
if isinstance(target_type, collections.abc.Container):
1465
if not isinstance(target_type, tuple):
1467
"The second argument to "
1468
"`torch.jit.isinstance` must be a type "
1469
"or a tuple of types"
1471
for t_type in target_type:
1472
if _isinstance(obj, t_type):
1476
origin_type = get_origin(target_type)
1478
return container_checker(obj, target_type)
1482
check_args_exist(target_type)
1485
return isinstance(obj, target_type)
1488
class _TensorExtractor(pickle.Pickler):
1489
def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
    """Build a pickler that collects every tensor it meets into `tensors`.

    All other positional and keyword arguments are forwarded unchanged to
    ``pickle.Pickler.__init__``.
    """
    super().__init__(*args, **kwargs)
    # Keep the caller's list object itself; persistent_id appends into it.
    self.tensors = tensors
1493
def persistent_id(self, obj):
1494
if isinstance(obj, torch.Tensor):
1495
self.tensors.append(obj)
1503
if isinstance(obj, LockType):
1507
if isinstance(obj, CFuture) or is_rref_instance(obj):
1509
if isinstance(obj, CAwait):
1511
if isinstance(obj, torch.cuda.Event):
1513
if isinstance(obj, threading.Thread):
1518
def _extract_tensors(obj):
1520
This function is exclusively called from C++.
1521
See ``torch/csrc/jit/python/python_ivalue.h``.
1523
It extracts the tensors contained in the given object, through pickling.
1525
tensors: List[torch.Tensor] = []
1526
extractor = _TensorExtractor(io.BytesIO(), protocol=-1, tensors=tensors)
1531
def _get_model_id(obj) -> Optional[str]:
1532
if isinstance(obj, torch.jit.ScriptModule):
1533
return str(obj._c._type())
1534
elif isinstance(obj, torch.jit.ScriptFunction):
1535
return obj.qualified_name
1543
if sys.version_info > (3, 10):
1544
_drop(enum.Enum.__new__)
1545
_drop(enum.Enum.__format__)
1546
_drop(enum.Enum.__repr__)
1547
_drop(enum.Enum.__str__)