# mypy: allow-untyped-defs
import copyreg
import functools
import logging
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Generic, List, Optional

from typing_extensions import ParamSpec

import torch


def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self

    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)


def _to(self, device, non_blocking=False):
    """Returns a copy of this object in device memory.

    If this object is already on the correct device, then no copy is performed
    and the original object is returned.

    Args:
        device (int): The destination device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
    """
    if self.device == device:
        return self

    device_module = getattr(torch, device.type, None)
    assert (
        device_module is not None
    ), f"{device.type.upper()} device module is not loaded"
    with device_module.device(device):
        if self.is_sparse and hasattr(device_module, "sparse"):
            new_type = getattr(device_module.sparse, self.__class__.__name__)
            indices = getattr(torch.Tensor._indices(self), device.type)(
                device, non_blocking
            )
            values = getattr(torch.Tensor._values(self), device.type)(
                device, non_blocking
            )
            return new_type(indices, values, self.size())
        else:
            assert (
                not self.is_sparse
            ), f"sparse storage is not supported for {device.type.upper()} tensors"
            untyped_storage = torch.UntypedStorage(self.size(), device=device)
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage


def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        **kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]


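# Illustrative doctest-style sketch (not part of the module): the deprecated
# ``async`` kwarg is accepted in place of ``non_blocking`` (with a warning),
# and any other kwarg is rejected.
#
#   >>> _get_async_or_non_blocking("type", False, {"async": True})  # warns
#   True
#   >>> _get_async_or_non_blocking("type", True, {})
#   True

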
def _get_restore_location(device):
    """Return the map_location location.

    Used for rebuild functions where the tensor device is distinct from the storage
    device.
    """
    map_location = torch.serialization._serialization_tls.map_location
    if map_location is None:
        return device
    else:
        if isinstance(map_location, dict):
            return map_location.get(device, device)
        elif isinstance(map_location, (str, torch.device)):
            return map_location
        else:
            assert callable(map_location)
            raise RuntimeError(
                "Callable map_location not supported with _rebuild_wrapper_subclass "
                "or _rebuild_device_tensor_from_numpy"
            )


# Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with
# variables. This kind of half-worked--Python can pickle global functions
# (but not closures!)--but there were problems.
#
#  - It's fragile. If you serialize a backward hook into a saved
#    model, and then you rename the function associated with the hook,
#    now your saved model is broken and you can't load it anymore.
#
#  - It's not actually used. The standard recommendation is to
#    serialize the *state_dict* of a model, not the model itself
#    (since this is more stable to code changes affecting the model
#    serialization), and the state dict saves "data" only, thus
#    stripping the backward hooks. In some cases, hooks are
#    essential to the well-functioning of a model (e.g., DDP),
#    but DDP already manages re-adding the hooks!
#
#  - We didn't serialize them in many cases. Prior to #10220, we
#    were dropping backward hooks in ForkingPickler. We "fixed" this
#    to be consistent with other serialization sites, but the lack of
#    serialized backward hooks wasn't actually the root cause of
#    the bug.
#
# With these cases in mind, we have decided that a better strategy
# is to just NOT serialize hooks at all.
#
# Since this is a BC-breaking change, we should warn when we previously
# serialized a hook, but no longer do so. This will be done by adding a special
# sentinel property to hooks that will be used to suppress this warning. If a hook
# has the property _torch_serialize_ignore, we will not emit a warning if we
# attempt to serialize a Tensor with this hook attached to it.
#
# By the way, when _backward_hooks is skipped, we must give an EMPTY
# OrderedDict(); if you pass a None you'll run afoul of #12219.


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)


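# Illustrative doctest-style sketch (not part of the module), assuming a
# TypedStorage as produced by ``Tensor.storage()``: rebuild a 2x2 view of the
# buffer starting at element offset 2.
#
#   >>> src = torch.arange(6.0)
#   >>> _rebuild_tensor(src.storage(), 2, (2, 2), (2, 1))
#   tensor([[2., 3.],
#           [4., 5.]])

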
def get_tensor_metadata(tensor):
    # Tensor's Metadata for serializing.
    # Currently, this only returns a dict[string, bool] specifying whether
    # `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


def _rebuild_tensor_v2(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    metadata=None,
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t


_sparse_tensors_to_validate: List["torch.Tensor"] = []


# In _legacy_load() in serialization.py we unpickle storages after the sparse
# tensors have been already unpickled. Those storages contain data necessary for
# validating sparse tensors: indices and values. That's why sparse tensors are
# first unpickled without any validation, and then this function is called just
# before _legacy_load() returns, so that all the sparse tensors can be validated
# in bulk.
#
# The same procedure must be followed by _load() in serialization.py because due
# to Pickler semantics, we have to use the same (non-validating) function for
# unpickling sparse tensors, regardless of the caller.
def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                # TODO: Validation currently involves an expensive traversal
                # on CPU, which may include a device transfer.
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )
    finally:
        _sparse_tensors_to_validate.clear()


def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (str): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # For BC:
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")


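# Illustrative round trip (not part of the module): validation is deferred to
# ``_validate_loaded_sparse_tensors()``, mirroring what the unpickler does.
#
#   >>> s = torch.eye(2).to_sparse().coalesce()
#   >>> r = _rebuild_sparse_tensor(
#   ...     torch.sparse_coo, (s._indices(), s._values(), s.size(), s.is_coalesced())
#   ... )
#   >>> _validate_loaded_sparse_tensors()  # flushes the pending-validation list
#   >>> bool((r.to_dense() == torch.eye(2)).all())
#   True

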
def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    device = _get_restore_location(device)
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy


def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )


def _rebuild_wrapper_subclass(
    cls, dtype, size, stride, storage_offset, layout, device, requires_grad
):
    device = _get_restore_location(device)
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        dtype=dtype,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )


# TODO: Once we decide to break serialization FC, `storage` no longer needs to
# be a TypedStorage
def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param


def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    # Restore state on Parameter like python attr.
    param = _set_obj_state(param, state)
    return param


def _get_obj_state(obj):
    # Get the state of the python subclass
    # This loosely mimics the function on the object class but since Tensor does
    # not inherit from it, we cannot call that function directly
    # https://github.com/python/cpython/blob/c83919bd635f4433f1c6ae8504996a9fe3c215e5/Objects/typeobject.c#L4891
    # Note that starting with Python 3.11, this `__getstate__` is always defined and thus
    # the else branch will never be taken.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state


def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)

    return obj


def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj


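# Illustrative doctest-style sketch (not part of the module): resolves a dotted
# path, e.g. the type strings stored by the legacy ``_type`` protocol, to the
# object it names.
#
#   >>> _import_dotted_name("torch.nn.functional") is torch.nn.functional
#   True

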
def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)


def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)


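# Illustrative round trip (not part of the module): flattening concatenates the
# tensors into one 1D buffer; unflattening returns same-shaped views of it.
#
#   >>> ts = [torch.ones(2), torch.zeros(3)]
#   >>> flat = _flatten_dense_tensors(ts)
#   >>> flat.shape
#   torch.Size([5])
#   >>> [u.shape for u in _unflatten_dense_tensors(flat, ts)]
#   [torch.Size([2]), torch.Size([3])]

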
def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of same sparse type, and that flat is given
    by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
            tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)


def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
            the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
            reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)


def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields a chunk at a time,
    each containing tensors of same type up to certain byte limit in total size.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within its types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf


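# Illustrative doctest-style sketch (not part of the module): three 512-byte
# float32 tensors with a 1024-byte limit yield a chunk of two, then one.
#
#   >>> chunks = list(_take_tensors([torch.ones(128)] * 3, 1024))
#   >>> [len(c) for c in chunks]
#   [2, 1]

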
# annotation decorator to get annotations in a way that is compatible
# with both Python 2 and 3
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec


def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r


# NOTE [ Python Traceback Reference Cycle Problem ]
#
# When using sys.exc_info(), it is important to **not** store the exc_info[2],
# which is the traceback, because otherwise you will run into the traceback
# reference cycle problem, i.e., the traceback holding reference to the frame,
# and the frame (which holds reference to all the objects in its temporary scope)
# holding reference to the traceback.


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate since we don't know how to
            raise RuntimeError(msg) from None
        raise exception


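# Illustrative doctest-style sketch (not part of the module): capture an
# exception in a worker thread and re-raise it, message and original traceback
# intact, in the calling thread.
#
#   >>> import threading
#   >>> box = []
#   >>> def worker():
#   ...     try:
#   ...         raise ValueError("bad sample")
#   ...     except Exception:
#   ...         box.append(ExceptionWrapper(where="in worker thread 0"))
#   >>> t = threading.Thread(target=worker); t.start(); t.join()
#   >>> box[0].reraise()  # raises: Caught ValueError in worker thread 0. ...

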
def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    if hasattr(torch, "mtia") and torch.mtia.is_available():
        return "mtia"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None


def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type and device_type.lower() == "mtia":
        return get_member(torch.mtia)
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None


def _get_current_device_index():
    # current device index
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # all device indices
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # all device properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]


def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1


def _get_device_index(
    device: Any,
    optional: bool = False,
    allow_cpu: bool = False,
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    e.g., the current default CUDA device will be returned if CUDA runtime is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got:{device}"
            )
    return device_idx


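# Illustrative doctest-style sketch (not part of the module): no device query
# happens when an index is already explicit, so this works on CPU-only builds.
#
#   >>> _get_device_index(torch.device("cuda:1"))
#   1
#   >>> _get_device_index("cpu", allow_cpu=True)
#   -1
#   >>> _get_device_index(3)
#   3

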
def _handle_complex(tensor):
    """
    Returns a real view of a tensor if complex dtype else just the tensor.
    We need to check for UninitializedParameter because otherwise checking
    is_complex is an error for a LazyModule.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )


def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3


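# Illustrative values (not part of the module): complex dtypes shift the bit
# width by 2 (divide by 4) because ``finfo`` reports the width of one real
# component, and a complex element stores two of them.
#
#   >>> _element_size(torch.float32), _element_size(torch.complex64)
#   (4, 8)
#   >>> _element_size(torch.bool), _element_size(torch.int64)
#   (1, 8)

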
class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)


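# Illustrative doctest-style sketch (not part of the module): a read-only
# property that is accessible on the class itself as well as on instances.
#
#   >>> class Config:
#   ...     @classproperty
#   ...     def label(cls):
#   ...         return cls.__name__.lower()
#   >>> Config.label
#   'config'
#   >>> Config().label
#   'config'

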
def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
    """
    return torch.compiler.is_compiling()


def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to intercept any ops that get called
        # when we sync our inner tensor.
        # Why?
        # (1) If there are input mutations in the graph, then they will be re-applied during
        #     AOTAutograd when we call _sync() from inside of our functionalization kernels.
        # (2) _sync() causes us to regenerate our updated tensor from the updated base,
        #     which dispatches to a bunch of view ops
        # (3) The input to these view ops is our inner FunctionalTensorWrapper
        #     (since the sync was called from C++), not the python FunctionalTensor
        # (4) if a python FunctionalTensorMode is active, it will complain when it intercepts
        #     the view op, since it will see an input that is a C++ FunctionalTensorWrapper
        #     (aka a normal torch.Tensor) instead of a python `FunctionalTensor`.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]


@functools.lru_cache(2)
def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module


def _dummy_type(name: str) -> type:
    def get_err_fn(is_init: bool):
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )


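# Illustrative doctest-style sketch (not part of the module): dummy classes
# stand in for C extension types on builds where the real backend is missing.
#
#   >>> _CudaStreamBase = _dummy_type("_CudaStreamBase")
#   >>> _CudaStreamBase()
#   Traceback (most recent call last):
#       ...
#   RuntimeError: Tried to instantiate dummy base class _CudaStreamBase

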
class _LazySeedTracker:
    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on current device. We track the order of **latest**
    # calls between these two APIs.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        return self.call_order


logger = logging.getLogger(__name__)
P = ParamSpec("P")


class CallbackRegistry(Generic[P]):
    def __init__(self, name: str):
        self.name = name
        self.callback_list: List[Callable[P, None]] = []

    def add_callback(self, cb: Callable[P, None]) -> None:
        self.callback_list.append(cb)

    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
        for cb in self.callback_list:
            try:
                cb(*args, **kwargs)
            except Exception as e:
                logger.exception(
                    "Exception in callback for %s registered with gpu trace", self.name
                )


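# Illustrative doctest-style sketch (not part of the module): callbacks fire in
# registration order, and a raising callback is logged without aborting the rest.
#
#   >>> reg = CallbackRegistry("demo events")
#   >>> reg.add_callback(lambda nbytes: print(f"saw {nbytes} bytes"))
#   >>> reg.fire_callbacks(1024)
#   saw 1024 bytes

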
# IMPORT_MAPPING and NAME_MAPPING are adapted from https://github.com/python/cpython/blob/main/Lib/_compat_pickle.py
# for use in the weights_only Unpickler.

IMPORT_MAPPING = {
    "__builtin__": "builtins",
    "copy_reg": "copyreg",
    "Queue": "queue",
    "repr": "reprlib",
    "_abcoll": "collections.abc",
    # Non-mutual mappings.
    "UserDict": "collections",
    "UserList": "collections",
    "UserString": "collections",
    "whichdb": "dbm",
    "StringIO": "io",
    "cStringIO": "io",
}


# This contains rename rules that are easy to handle. We ignore the more
# complex stuff (e.g. mapping the names in the urllib and types modules).
# These rules should be run before import names are fixed.
NAME_MAPPING = {
    ("__builtin__", "xrange"): ("builtins", "range"),
    ("__builtin__", "reduce"): ("functools", "reduce"),
    ("__builtin__", "intern"): ("sys", "intern"),
    ("__builtin__", "unichr"): ("builtins", "chr"),
    ("__builtin__", "unicode"): ("builtins", "str"),
    ("__builtin__", "long"): ("builtins", "int"),
    ("itertools", "izip"): ("builtins", "zip"),
    ("itertools", "imap"): ("builtins", "map"),
    ("itertools", "ifilter"): ("builtins", "filter"),
    ("itertools", "ifilterfalse"): ("itertools", "filterfalse"),
    ("itertools", "izip_longest"): ("itertools", "zip_longest"),
    ("UserDict", "IterableUserDict"): ("collections", "UserDict"),
    ("UserList", "UserList"): ("collections", "UserList"),
    ("UserString", "UserString"): ("collections", "UserString"),
    # Non-mutual mappings.
    ("__builtin__", "basestring"): ("builtins", "str"),
    ("exceptions", "StandardError"): ("builtins", "Exception"),
    ("UserDict", "UserDict"): ("collections", "UserDict"),
}