import copyreg
import functools
import sys
import traceback
import warnings
from collections import defaultdict
from typing import Any, DefaultDict, List, Optional

import torch


def _type(self, dtype=None, non_blocking=False, **kwargs):
    """Returns the type if `dtype` is not provided, else casts this object to
    the specified type.

    If this is already of the correct type, no copy is performed and the
    original object is returned.

    Args:
        dtype (type or string): The desired type
        non_blocking (bool): If ``True``, and the source is in pinned memory
            and destination is on the GPU or vice versa, the copy is performed
            asynchronously with respect to the host. Otherwise, the argument
            has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument. The ``async`` arg is deprecated.
    """
    non_blocking = _get_async_or_non_blocking("type", non_blocking, kwargs)
    if dtype is None:
        return self.__module__ + "." + self.__class__.__name__

    if isinstance(dtype, str):
        dtype = _import_dotted_name(dtype)
    if dtype == type(self):
        return self
    if self.is_sparse:
        if not dtype.is_sparse:
            raise RuntimeError("Cannot cast sparse tensor to dense tensor")
        new_module_name = dtype.__module__.replace(".sparse", "")
        new_values_type_name = new_module_name + "." + dtype.__name__
        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
        new_indices_type_name = new_module_name + ".LongTensor"
        new_indices = torch.Tensor._indices(self).type(
            new_indices_type_name, non_blocking
        )
        return dtype(new_indices, new_values, self.size())
    if dtype.is_sparse:
        raise RuntimeError("Cannot cast dense tensor to sparse tensor")
    return dtype(self.size()).copy_(self, non_blocking)


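# Illustrative usage of `_type` (this example is not part of the original file
# and assumes the legacy typed-storage API is available):
#
#   >>> s = torch.FloatStorage(4)
#   >>> _type(s)                  # no dtype given: returns the type name
#   'torch.FloatStorage'
#   >>> _type(s, torch.DoubleStorage).__class__.__name__   # casts via copy_
#   'DoubleStorage'

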
def _hpu(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in HPU memory.

    If this object is already in HPU memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination HPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("hpu", non_blocking, kwargs)
    hpu = getattr(torch, "hpu", None)
    assert hpu is not None, "HPU device module is not loaded"
    if self.is_hpu:
        if device is None:
            device = hpu.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with hpu.device(device):
        assert not self.is_sparse, "sparse storage is not supported for HPU tensors"
        untyped_storage = torch.UntypedStorage(self.size(), device=torch.device("hpu"))
        untyped_storage.copy_(self, non_blocking)
        return untyped_storage


def _cuda(self, device=None, non_blocking=False, **kwargs):
    """Returns a copy of this object in CUDA memory.

    If this object is already in CUDA memory and on the correct device, then
    no copy is performed and the original object is returned.

    Args:
        device (int): The destination GPU id. Defaults to the current device.
        non_blocking (bool): If ``True`` and the source is in pinned memory,
            the copy will be asynchronous with respect to the host. Otherwise,
            the argument has no effect.
        **kwargs: For compatibility, may contain the key ``async`` in place of
            the ``non_blocking`` argument.
    """
    non_blocking = _get_async_or_non_blocking("cuda", non_blocking, kwargs)
    if self.is_cuda:
        if device is None:
            device = torch.cuda.current_device()
        if self.get_device() == device:
            return self
    else:
        if device is None:
            device = -1
    with torch.cuda.device(device):
        if self.is_sparse:
            new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
            indices = torch.Tensor._indices(self).cuda(device, non_blocking)
            values = torch.Tensor._values(self).cuda(device, non_blocking)
            return new_type(indices, values, self.size())
        else:
            untyped_storage = torch.UntypedStorage(
                self.size(), device=torch.device("cuda")
            )
            untyped_storage.copy_(self, non_blocking)
            return untyped_storage


def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
    """Return the non-blocking flag given the function name and kwargs.

    Args:
        function_name (str): the name of the function being used.
        non_blocking (bool): the default value.
        **kwargs (dict): the kwargs passed to the function.
    """
    if not kwargs:
        return non_blocking
    if len(kwargs) != 1 or "async" not in kwargs:
        message = "{}() got an unexpected keyword argument '{}'"
        argument = list(kwargs.keys()).pop()
        raise TypeError(message.format(function_name, argument))
    warnings.warn("'async' is deprecated; use 'non_blocking'")
    return kwargs["async"]


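# Illustrative behavior (not part of the original file):
#
#   >>> _get_async_or_non_blocking("cuda", False, {})
#   False
#   >>> _get_async_or_non_blocking("cuda", False, {"async": True})  # deprecated spelling
#   True
#   >>> _get_async_or_non_blocking("cuda", False, {"foo": 1})
#   TypeError: cuda() got an unexpected keyword argument 'foo'

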
def _rebuild_tensor(storage, storage_offset, size, stride):
    # first construct a tensor with the correct dtype/device
    t = torch.empty((0,), dtype=storage.dtype, device=storage._untyped_storage.device)
    return t.set_(storage._untyped_storage, storage_offset, size, stride)


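# Illustrative round trip (not part of the original file): `_rebuild_tensor`
# creates a view over an existing typed storage without copying.
#
#   >>> x = torch.arange(6.0)
#   >>> y = _rebuild_tensor(x.storage(), 0, (2, 3), (3, 1))
#   >>> y.shape
#   torch.Size([2, 3])
#   >>> y.data_ptr() == x.data_ptr()
#   True

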
def get_tensor_metadata(tensor):
    # Tensor's Metadata for serializing.
    # Currently, this only returns a dict[string, bool] specifying whether
    # `conj` or `neg` bit is set.
    assert isinstance(tensor, torch.Tensor)
    return torch._C._get_tensor_metadata(tensor)  # type: ignore[attr-defined]


def set_tensor_metadata(tensor, metadata):
    # See `get_tensor_metadata` above
    assert isinstance(metadata, dict)
    assert isinstance(tensor, torch.Tensor)
    torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


def _rebuild_tensor_v2(
    storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None
):
    tensor = _rebuild_tensor(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    if metadata:
        set_tensor_metadata(tensor, metadata)

    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [ Don't serialize hooks ]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_tensor_v3(
    storage,
    storage_offset,
    size,
    stride,
    requires_grad,
    backward_hooks,
    dtype,
    metadata=None,
):
    t = torch.empty(
        (0,),
        dtype=dtype,
        device=storage._untyped_storage.device,
        requires_grad=requires_grad,
    )
    t.set_(storage._untyped_storage, storage_offset, size, stride)
    if metadata:
        set_tensor_metadata(t, metadata)
    t._backward_hooks = backward_hooks
    return t


# Sparse tensors rebuilt during unpickling are queued here and checked for
# consistency by _validate_loaded_sparse_tensors() once loading completes.
_sparse_tensors_to_validate: List["torch.Tensor"] = []


def _validate_loaded_sparse_tensors():
    try:
        for t in _sparse_tensors_to_validate:
            if t.layout is torch.sparse_coo:
                torch._validate_sparse_coo_tensor_args(
                    t._indices(), t._values(), t.size(), t.is_coalesced()
                )
            elif t.layout in {
                torch.sparse_csr,
                torch.sparse_csc,
                torch.sparse_bsr,
                torch.sparse_bsc,
            }:
                if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    compressed_indices, plain_indices = (
                        t.crow_indices(),
                        t.col_indices(),
                    )
                else:
                    compressed_indices, plain_indices = (
                        t.ccol_indices(),
                        t.row_indices(),
                    )
                torch._validate_sparse_compressed_tensor_args(
                    compressed_indices, plain_indices, t.values(), t.size(), t.layout
                )
            else:
                raise NotImplementedError(
                    f"_validate_loaded_sparse_tensors for layout `{t.layout}`"
                )
    finally:
        _sparse_tensors_to_validate.clear()


def _rebuild_sparse_tensor(layout, data):
    """
    Rebuilds a sparse tensor from its sparse storage representation.

    Args:
        layout (str): The sparse storage layout of the tensor.
        data (tuple): The tensor's sparse storage representation.
    """
    if layout == torch.sparse_coo:
        if len(data) == 3:
            # legacy case: is_coalesced was not stored
            indices, values, size = data
            is_coalesced = None
        else:
            indices, values, size, is_coalesced = data
        result = torch.sparse_coo_tensor(
            indices, values, size, check_invariants=False, is_coalesced=is_coalesced
        )
        _sparse_tensors_to_validate.append(result)
        return result

    elif layout in {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }:
        compressed_indices, plain_indices, values, size = data
        result = torch.sparse_compressed_tensor(
            compressed_indices,
            plain_indices,
            values,
            size,
            layout=layout,
            check_invariants=False,
        )
        _sparse_tensors_to_validate.append(result)
        return result

    raise NotImplementedError(f"rebuilding sparse tensor for layout {layout}")


def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
    return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets)


def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy


def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
    return torch.empty_strided(
        size, stride, dtype=dtype, device="meta", requires_grad=requires_grad
    )


def _rebuild_wrapper_subclass(
    cls, dtype, size, stride, storage_offset, layout, device, requires_grad
):
    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
        cls,
        size,
        strides=stride,
        dtype=dtype,
        storage_offset=storage_offset,
        layout=layout,
        device=device,
        requires_grad=requires_grad,
    )


def _rebuild_qtensor(
    storage,
    storage_offset,
    size,
    stride,
    quantizer_params,
    requires_grad,
    backward_hooks,
):
    qscheme = quantizer_params[0]
    if qscheme == torch.per_tensor_affine:
        _, scale, zero_point = quantizer_params
        tensor = torch._empty_affine_quantized(
            size,
            scale=scale,
            zero_point=zero_point,
            dtype=storage.dtype,
            device=storage.device,
        )
    elif qscheme in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
        _, scales, zero_points, axis = quantizer_params
        if type(scales) is list and type(zero_points) is list:
            if qscheme == torch.per_channel_affine:
                scales = torch.tensor(scales, dtype=torch.double, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.long, device=storage.device
                )
            else:
                scales = torch.tensor(scales, dtype=torch.float, device=storage.device)
                zero_points = torch.tensor(
                    zero_points, dtype=torch.float, device=storage.device
                )
        tensor = torch._empty_per_channel_affine_quantized(
            size,
            scales=scales,
            zero_points=zero_points,
            axis=axis,
            dtype=storage.dtype,
            device=storage.device,
        )
    else:
        raise RuntimeError(f"Can't deserialize quantized tensor with qscheme {qscheme}")
    tensor.set_(storage, storage_offset, size, stride)
    tensor.requires_grad = requires_grad
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [ Don't serialize hooks ]
    tensor._backward_hooks = backward_hooks
    return tensor


def _rebuild_parameter(data, requires_grad, backward_hooks):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [ Don't serialize hooks ]
    param._backward_hooks = backward_hooks

    return param


def _rebuild_parameter_with_state(data, requires_grad, backward_hooks, state):
    param = torch.nn.Parameter(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [ Don't serialize hooks ]
    param._backward_hooks = backward_hooks

    # Restore state on Parameter like python attr.
    param = _set_obj_state(param, state)
    return param


def _get_obj_state(obj):
    # Get the state of the python subclass.
    # This loosely mimics object.__getstate__, which Tensor does not inherit,
    # so we cannot call it directly.
    getstate_fn = getattr(obj, "__getstate__", None)
    if getstate_fn:
        state = getstate_fn()
    else:
        slots_to_save = copyreg._slotnames(obj.__class__)  # type: ignore[attr-defined]
        if slots_to_save:
            state = (
                obj.__dict__,
                {
                    name: getattr(obj, name)
                    for name in slots_to_save
                    if hasattr(obj, name)
                },
            )
        else:
            state = obj.__dict__

    return state


def _set_obj_state(obj, state):
    if isinstance(state, tuple):
        if not len(state) == 2:
            raise RuntimeError(f"Invalid serialized state: {state}")
        dict_state = state[0]
        slots_state = state[1]
    else:
        dict_state = state
        slots_state = None

    # Starting with Python 3.11, the __dict__ attribute is lazily created
    # and is serialized as None when not needed.
    if dict_state:
        for k, v in dict_state.items():
            setattr(obj, k, v)

    if slots_state:
        for k, v in slots_state.items():
            setattr(obj, k, v)
    return obj


def _import_dotted_name(name):
    components = name.split(".")
    obj = __import__(components[0])
    for component in components[1:]:
        obj = getattr(obj, component)
    return obj


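# Illustrative usage (not part of the original file):
#
#   >>> _import_dotted_name("torch.nn.functional").__name__
#   'torch.nn.functional'
#   >>> _import_dotted_name("collections.OrderedDict")
#   <class 'collections.OrderedDict'>

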
def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.

    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.

    Args:
        tensors (Iterable[Tensor]): dense tensors to flatten.

    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    return torch._C._nn.flatten_dense_tensors(tensors)


def _flatten_sparse_tensors(tensors):
    """Flatten sparse tensors into two contiguous 1D buffers, one of indices and
    one of values. Assume tensors are of same sparse type.

    Args:
        tensors (Iterable[Tensor]): sparse tensors to flatten.

    Returns:
        A tuple of two contiguous 1D buffers, one containing input tensors'
        indices and the other containing the values.
    """
    flat_indices = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._indices(t) for t in tensors]
    )
    flat_values = torch._C._nn.flatten_dense_tensors(
        [torch.Tensor._values(t) for t in tensors]
    )
    return flat_indices, flat_values


def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.

    Args:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    return torch._C._nn.unflatten_dense_tensors(flat, tensors)


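# Illustrative round trip (not part of the original file): flattening then
# unflattening recovers tensors with the original shapes and values.
#
#   >>> tensors = [torch.randn(2, 3), torch.randn(5)]
#   >>> flat = _flatten_dense_tensors(tensors)
#   >>> flat.shape
#   torch.Size([11])
#   >>> outputs = _unflatten_dense_tensors(flat, tensors)
#   >>> all(torch.equal(a, b) for a, b in zip(tensors, outputs))
#   True

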
def _unflatten_sparse_tensors(flat, tensors):
    """View flat buffer (containing indices and values) using the sizes of
    tensors. Assume that tensors are of same sparse type, and that flat is given
    by _flatten_sparse_tensors.

    Args:
        flat (tuple(Tensor, Tensor)): flattened indices and values of sparse
            tensors to unflatten.
        tensors (Iterable[Tensor]): sparse tensors whose sizes will be used to
            unflatten flat.

    Returns:
        Unflattened sparse tensors with sizes same as tensors and values from
        flat.
    """
    flat_indices, flat_values = flat
    indices = torch._C._nn.unflatten_dense_tensors(
        flat_indices, [torch.Tensor._indices(t) for t in tensors]
    )
    values = torch._C._nn.unflatten_dense_tensors(
        flat_values, [torch.Tensor._values(t) for t in tensors]
    )
    outputs = []
    for t, i, v in zip(tensors, indices, values):
        outputs.append(t.new(i, v, t.size()))
    return tuple(outputs)


def _reorder_tensors_as(tensors, ordered_tensors):
    """Assume that tensors are of same order as ordered_tensors within their
    types, e.g., from _take_tensors. Reorder them to be of same order as
    ordered_tensors.

    Args:
        tensors (Iterable[Tensor]): tensors to be reordered. They should be of
            the same order as ordered_tensors within their own types.
        ordered_tensors (Iterable[Tensor]): tensors whose order will be the
            reference.

    Returns:
        Ordered tuple of tensors with contents from tensors and order of
        ordered_tensors.
    """
    type_dict = defaultdict(list)
    for tensor in tensors:
        type_dict[tensor.type()].append(tensor)
    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)


def _take_tensors(tensors, size_limit):
    """Group tensors into chunks. This generator yields a chunk at each time,
    each containing tensors of same type up to certain byte limit in total size.

    Args:
        tensors (Sequence): A sequence of tensors to be separated into chunks.
        size_limit (int): The limit of each chunk in bytes.

    Yields:
        Blocks of tensors of same type and within size_limit. The yielded
        tensors are only ordered as the original sequence within its types.
    """
    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
    for tensor in tensors:
        t = tensor.type()
        if tensor.is_sparse:
            indices = torch.Tensor._indices(tensor)
            values = torch.Tensor._values(tensor)
            size = (
                indices.numel() * indices.element_size()
                + values.numel() * values.element_size()
            )
        else:
            size = tensor.numel() * tensor.element_size()
        buf_and_size = buf_dict[t]
        if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
            yield buf_and_size[0]
            buf_and_size = buf_dict[t] = [[], 0]
        buf_and_size[0].append(tensor)
        buf_and_size[1] += size
    for buf, _ in buf_dict.values():
        if len(buf) > 0:
            yield buf


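# Illustrative usage (not part of the original file): three float32 tensors of
# 16 bytes each, chunked under a 32-byte limit.
#
#   >>> tensors = [torch.zeros(4), torch.zeros(4), torch.zeros(4)]
#   >>> [len(chunk) for chunk in _take_tensors(tensors, size_limit=32)]
#   [2, 1]

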
def annotate(ret, **kwargs):
    def dec(fun):
        fun.__annotations__ = dict(kwargs)
        fun.__annotations__["return"] = ret
        return fun

    return dec


def render_call(fn, args, kwargs):
    str_fn = torch.overrides.resolve_name(fn)
    if str_fn is None:
        str_fn = str(fn)

    str_args: List[str] = []
    with torch._tensor_str.printoptions(threshold=0, edgeitems=0):
        str_args.extend(repr(a) for a in args)
        str_args.extend(f"{k}={repr(v)}" for k, v in kwargs.items())
        r = f"{str_fn}({', '.join(str_args)})"
    return r


class KeyErrorMessage(str):
    r"""str subclass that returns itself in repr"""

    def __repr__(self):
        return self


class ExceptionWrapper:
    r"""Wraps an exception plus traceback to communicate across threads"""

    def __init__(self, exc_info=None, where="in background"):
        # It is important that we don't store exc_info, see
        # NOTE [ Python Traceback Reference Cycle Problem ]
        if exc_info is None:
            exc_info = sys.exc_info()
        self.exc_type = exc_info[0]
        self.exc_msg = "".join(traceback.format_exception(*exc_info))
        self.where = where

    def reraise(self):
        r"""Reraises the wrapped exception in the current thread"""
        # Format a message such as: "Caught ValueError in DataLoader worker
        # process 2. Original Traceback:", followed by the traceback.
        msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}"
        if self.exc_type == KeyError:
            # KeyError calls repr() on its argument (usually a dict key). This
            # makes stack traces unreadable. It will not be changed in Python
            # (https://bugs.python.org/issue2651), so we work around it.
            msg = KeyErrorMessage(msg)
        elif getattr(self.exc_type, "message", None):
            # Some exceptions have first argument as non-str but explicitly
            # have message field
            raise self.exc_type(message=msg)
        try:
            exception = self.exc_type(msg)
        except TypeError:
            # If the exception takes multiple arguments, don't try to
            # instantiate since we don't know how to
            raise RuntimeError(msg) from None
        raise exception


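# Illustrative usage (not part of the original file): capture an exception in a
# worker and re-raise it, with context, on the consuming thread.
#
#   >>> try:
#   ...     1 / 0
#   ... except Exception:
#   ...     wrapper = ExceptionWrapper(where="in worker thread")
#   >>> wrapper.reraise()
#   ZeroDivisionError: Caught ZeroDivisionError in worker thread.
#   Original Traceback (most recent call last):
#   ...

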
def _get_available_device_type():
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
        return "xpu"
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    custom_device_mod = getattr(torch, custom_backend_name, None)
    if custom_device_mod and custom_device_mod.is_available():
        return custom_backend_name
    # add more available device types here
    return None


def _get_device_attr(get_member):
    device_type = _get_available_device_type()
    if device_type and device_type.lower() == "cuda":
        return get_member(torch.cuda)
    if device_type and device_type.lower() == "xpu":
        return get_member(torch.xpu)  # type: ignore[attr-defined]
    if device_type == torch._C._get_privateuse1_backend_name():
        return get_member(getattr(torch, device_type))
    # add more available device types here
    return None


def _get_current_device_index():
    # current device index
    return _get_device_attr(lambda m: m.current_device())


def _get_all_device_indices():
    # all device index
    return _get_device_attr(lambda m: list(range(m.device_count())))


def _get_devices_properties(device_ids):
    # all device properties
    return [_get_device_attr(lambda m: m.get_device_properties(i)) for i in device_ids]


def get_current_device_index() -> int:
    r"""Checks if there are CUDA devices available and
    returns the device index of the current default CUDA device.
    Returns -1 in case there are no CUDA devices available.
    """
    if torch.cuda.device_count() > 0:
        return torch.cuda.current_device()
    return -1


def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Gets the device index from :attr:`device`, which can be a torch.device
    object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    has index. Note that for a device without a specified index,
    i.e., ``torch.device('xxx')``, this will return the current default
    device of that type if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default
    device of the supported runtime platform if :attr:`optional` is ``True``,
    i.e., the current default CUDA device will be returned if CUDA runtime is supported.
    """
    if isinstance(device, str):
        device = torch.device(device)
    device_idx: Optional[int] = None
    if isinstance(device, torch.device):
        if not allow_cpu and device.type == "cpu":
            raise ValueError(f"Expected a non cpu device, but got: {device}")
        device_idx = -1 if device.type == "cpu" else device.index
    if isinstance(device, int):
        device_idx = device
    if device_idx is None:
        if optional:
            # The eager API _get_current_device_index uses `lambda` functions which are
            # not supported in JIT and hence not scriptable. The JIT equivalent API to get
            # the current device index is `get_current_device_index()` which can
            # be scripted. We use is_scripting to check the mode we are in and call the
            # appropriate API.
            if torch.jit.is_scripting():
                device_idx = get_current_device_index()
            else:
                device_idx = _get_current_device_index()
        else:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got:{device}"
            )
    return device_idx


def _handle_complex(tensor):
    """
    Returns a real view of the tensor if it has a complex dtype, else the
    tensor itself. We must check for UninitializedParameter first because
    calling is_complex on a LazyModule's parameter is an error.
    """
    return (
        torch.view_as_real(tensor)
        if not isinstance(tensor, torch.nn.UninitializedParameter)
        and tensor.is_complex()
        else tensor
    )


def _element_size(dtype):
    """
    Returns the element size for a dtype, in bytes
    """
    if not isinstance(dtype, torch.dtype):
        raise RuntimeError(f"expected torch.dtype, but got {type(dtype)}")

    if dtype.is_complex:
        return torch.finfo(dtype).bits >> 2
    elif dtype.is_floating_point:
        return torch.finfo(dtype).bits >> 3
    elif dtype == torch.bool:
        # NOTE: torch.bool is not supported in torch.iinfo()
        return 1
    else:
        return torch.iinfo(dtype).bits >> 3


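# Illustrative values (not part of the original file): complex dtypes shift by
# 2 rather than 3 because finfo reports the bit width of each real component.
#
#   >>> _element_size(torch.float32)
#   4
#   >>> _element_size(torch.complex64)
#   8
#   >>> _element_size(torch.bool)
#   1

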
class _ClassPropertyDescriptor:
    def __init__(self, fget, fset=None):
        self.fget = fget

    def __get__(self, instance, owner=None):
        if owner is None:
            owner = type(instance)
        return self.fget.__get__(instance, owner)()


def classproperty(func):
    if not isinstance(func, (classmethod, staticmethod)):
        func = classmethod(func)
    return _ClassPropertyDescriptor(func)


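# Illustrative usage (not part of the original file; `Config` is a hypothetical
# class used only for demonstration):
#
#   >>> class Config:
#   ...     @classproperty
#   ...     def default_name(cls):
#   ...         return cls.__name__.lower()
#   >>> Config.default_name
#   'config'

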
def is_compiling() -> bool:
    """
    Indicates whether we are tracing/compiling with torch.compile() or torch.export().

    TODO(khabinov): we should deprecate this function and use torch.compiler.is_compiling().
    """
    return torch.compiler.is_compiling()


def _functionalize_sync(t):
    # This code lives in python instead of C++ since conditioning on a certain python subclass
    # is much more of a pain in C++.
    from torch._subclasses.functional_tensor import FunctionalTensor

    if isinstance(t, FunctionalTensor):
        # If a FunctionalTensorMode is active while syncing, we don't want it to
        # intercept any ops called when we sync our inner tensor. Syncing runs
        # view ops on the inner C++ FunctionalTensorWrapper, and an active
        # python mode would complain when it intercepts them, so we temporarily
        # unset the mode and restore it afterwards.
        maybe_functional_mode = torch._C._unset_dispatch_mode(
            torch._C._TorchDispatchModeKey.FUNCTIONAL
        )
        try:
            torch._functionalize_sync(t.elem)  # type: ignore[attr-defined]
        finally:
            if maybe_functional_mode is not None:
                torch._C._set_dispatch_mode(maybe_functional_mode)
    else:
        torch._functionalize_sync(t)  # type: ignore[attr-defined]


@functools.lru_cache(2)
def _get_device_module(device_type: str):
    device_module = getattr(torch, device_type, None)
    if device_module is None:
        raise RuntimeError(
            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
        )
    return device_module


def _dummy_type(name: str) -> type:
    def get_err_fn(is_init: bool):
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )


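# Illustrative usage (not part of the original file; the class name below is
# hypothetical): dummy types stand in for backend classes that failed to load.
#
#   >>> _StreamBase = _dummy_type("_StreamBase")
#   >>> _StreamBase()
#   RuntimeError: Tried to instantiate dummy base class _StreamBase

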
class _LazySeedTracker:
    # Track only the latest seed callbacks and the order in which the two
    # variants were most recently queued, so they can be replayed lazily on
    # device initialization.
    def __init__(self):
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        return self.call_order