import collections
import copy
import io
import threading
import warnings

import torch
from ._utils import _type, _cuda, _hpu
from torch.types import Storage
from typing import cast, Any, Dict as _Dict, Optional as _Optional, TypeVar, Type, Union
from functools import lru_cache

try:
    import numpy as np
    HAS_NUMPY = True
except ModuleNotFoundError:
    np = None
    HAS_NUMPY = False

_share_memory_lock = threading.Lock()
_share_memory_map: _Dict[int, threading.RLock] = {}
T = TypeVar('T', bound='Union[_StorageBase, TypedStorage]')


class _StorageBase:
    is_sparse: bool = False
    is_sparse_csr: bool = False

    def __init__(self, *args, **kwargs): ...
    def __len__(self) -> int: ...
    def __getitem__(self, idx): ...
    def __setitem__(self, *args, **kwargs): ...
    def copy_(self, source: T, non_blocking: _Optional[bool] = None) -> T: ...
    def new(self) -> T: ...
    def nbytes(self) -> int: ...

    def size(self) -> int:
        return self.nbytes()

    def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> T: ...
    def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ...
    def hpu(self, device=None, non_blocking=False, **kwargs) -> T: ...
    def element_size(self) -> int: ...

    def get_device(self) -> int:
        return self.device.index

    def data_ptr(self) -> int: ...
    def resizable(self) -> bool: ...

    # Placeholders for methods provided by the C backend (torch._C.StorageBase).
    def _share_filename_cpu_(self, *args, **kwargs): ...
    def _share_fd_cpu_(self, *args, **kwargs): ...
    def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ...
    def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ...
    def from_buffer(cls: Type[T], *args, **kwargs) -> T: ...
    def _new_shared_filename_cpu(cls: Type[T], manager, obj, size, *, device=None, dtype=None) -> T: ...
    def _release_ipc_counter_cuda(cls: Type[T], *args, **kwargs) -> T: ...
    def _new_with_weak_ptr(cls: Type[T], *args, **kwargs) -> T: ...
    def _shared_decref(self) -> T: ...
    def _write_file(self, *args, **kwargs): ...
    def resize_(self, size: int): ...
    def _weak_ref(self, *args, **kwargs) -> T: ...
    def _set_from_file(self, *args, **kwargs): ...
    def _set_cdata(self, *args, **kwargs): ...
    def _share_cuda_(self, *args, **kwargs): ...
    def is_shared(self) -> bool: ...
    def _new_shared_cuda(cls: Type[T], *args, **kwargs) -> T: ...
    def _shared_incref(self, *args, **kwargs): ...
    def _free_weak_ref(cls, *args, **kwargs): ...
    def is_cuda(self): ...
    def from_file(cls, filename, shared, nbytes) -> T: ...
    def _expired(cls, *args, **kwargs) -> T: ...
    def _byteswap(self, *args, **kwargs): ...
    def _get_filename(self, *args, **kwargs) -> _Optional[str]: ...
    def __str__(self):
        info_str = (
            f'[{torch.typename(self)}(device={self.device}) '
            f'of size {len(self)}]')
        if self.device.type == 'meta':
            return '...\n' + info_str
        data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
        return data_str + '\n' + info_str

    def __iter__(self):
        return iter(self[i] for i in range(self.size()))

    def __deepcopy__(self, memo):
        memo = memo.setdefault('torch', {})
        if self._cdata in memo:
            return memo[self._cdata]
        new_storage = self.clone()
        memo[self._cdata] = new_storage
        return new_storage

    def __reduce__(self):
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))
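    # Illustrative sketch: because __reduce__ above serializes through
    # torch.save / torch.load, storages survive a plain pickle round trip.
    #
    #     >>> import pickle
    #     >>> s = torch.arange(4, dtype=torch.uint8).untyped_storage()
    #     >>> list(pickle.loads(pickle.dumps(s)))
    #     [0, 1, 2, 3]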
    def __sizeof__(self):
        return super().__sizeof__() + self.size()

    def clone(self):
        """Return a copy of this storage."""
        return type(self)(self.nbytes(), device=self.device).copy_(self)

    def tolist(self):
        """Return a list containing the elements of this storage."""
        return list(self)

    def cpu(self):
        """Return a CPU copy of this storage if it's not already on the CPU."""
        if self.device.type != 'cpu':
            return torch.UntypedStorage(self.size()).copy_(self, False)
        return self

    def mps(self):
        """Return an MPS copy of this storage if it's not already on MPS."""
        if self.device.type != 'mps':
            return torch.UntypedStorage(self.size(), device="mps").copy_(self, False)
        return self

    def _to(self, dtype):
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype)._typed_storage()
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage

    def double(self):
        """Casts this storage to double type."""
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type."""
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type."""
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type."""
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type."""
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type."""
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type."""
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type."""
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type."""
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type."""
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type."""
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type."""
        return self._to(torch.cfloat)

    def float8_e5m2(self):
        """Casts this storage to float8_e5m2 type."""
        return self._to(torch.float8_e5m2)

    def float8_e4m3fn(self):
        """Casts this storage to float8_e4m3fn type."""
        return self._to(torch.float8_e4m3fn)

    def float8_e5m2fnuz(self):
        """Casts this storage to float8_e5m2fnuz type."""
        return self._to(torch.float8_e5m2fnuz)

    def float8_e4m3fnuz(self):
        """Casts this storage to float8_e4m3fnuz type."""
        return self._to(torch.float8_e4m3fnuz)
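    # Illustrative sketch: the cast helpers above view the storage as raw uint8
    # bytes and value-convert them to the target dtype, returning a typed storage.
    #
    #     >>> s = torch.arange(3, dtype=torch.uint8).untyped_storage()
    #     >>> s.float().tolist()        # each byte converted to a float
    #     [0.0, 1.0, 2.0]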
    def is_pinned(self, device: Union[str, torch.device] = 'cuda'):
        r"""Determine whether the CPU storage is already pinned on device.

        Args:
            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.
        """
        return torch.tensor([], dtype=torch.uint8, device=self.device).set_(
            cast(Storage, self)).is_pinned(device)

    def pin_memory(self, device: Union[str, torch.device] = 'cuda'):
        r"""Copy the CPU storage to pinned memory, if it's not already pinned.

        Args:
            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.

        Returns:
            A pinned CPU storage.
        """
        if self.device.type != 'cpu':
            raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")

        pinned_tensor = torch.tensor([], dtype=torch.uint8, device=self.device).set_(
            cast(Storage, self)).pin_memory(device)
        return pinned_tensor.untyped_storage()
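    # Illustrative sketch: pinning a CPU storage into page-locked memory for
    # faster asynchronous host-to-device copies (assumes a CUDA-enabled build).
    #
    #     >>> s = torch.UntypedStorage(1024)
    #     >>> s.is_pinned()
    #     False
    #     >>> s.pin_memory().is_pinned()
    #     True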
    def share_memory_(self):
        """See :meth:`torch.UntypedStorage.share_memory_`"""
        from torch.multiprocessing import get_sharing_strategy
        if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
            pass  # CUDA storages do not need to be moved into shared memory
        elif get_sharing_strategy() == 'file_system':
            self._share_filename_cpu_()
        else:
            self._share_fd_cpu_()
        return self

    @classmethod
    def _new_shared(cls, size, *, device='cpu'):
        """Create a new storage in shared memory with the same data type."""
        from torch.multiprocessing import get_sharing_strategy
        device = torch.device(device)
        if device.type in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]:
            return cls(size, device=device)
        elif get_sharing_strategy() == 'file_system':
            return cls._new_using_filename_cpu(size)
        else:
            return cls._new_using_fd_cpu(size)
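    # Illustrative sketch: allocating a storage that is shared from the start
    # instead of sharing an existing one in place.
    #
    #     >>> import torch.multiprocessing as mp
    #     >>> mp.set_sharing_strategy('file_system')            # or keep the platform default
    #     >>> shared = torch.UntypedStorage._new_shared(256)    # 256 shared bytes on CPU
    #     >>> shared.is_shared()
    #     True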
    def byteswap(self, dtype):
        """Swap bytes in underlying data."""
        elem_size = torch._utils._element_size(dtype)
        # for complex types, swap each real/imaginary half separately
        if dtype.is_complex:
            elem_size = max(int(elem_size / 2), 1)
        self._byteswap(elem_size)


def _share_memory_lock_protected(fn):
    def wrapper(self, *args, **kwargs):
        to_free = None
        to_wait = None
        with _share_memory_lock:
            key = self._cdata
            if key in _share_memory_map:
                to_wait = _share_memory_map[key]
            else:
                _share_memory_map[key] = threading.RLock()
                _share_memory_map[key].acquire()
                to_free = key

        # If another thread is already sharing this storage, wait for it to finish.
        if to_wait is not None:
            with to_wait:
                pass

        try:
            return fn(self, *args, **kwargs)
        finally:
            # If this thread acquired the per-storage lock above, release it and
            # drop the map entry now that sharing is done.
            if to_free is not None:
                # The cdata of the storage must not have changed, only its data_ptr.
                assert self._cdata == to_free
                with _share_memory_lock:
                    _share_memory_map[to_free].release()
                    del _share_memory_map[to_free]

    return wrapper
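# Illustrative sketch: the decorator above keys a private threading.RLock on each
# storage's _cdata, so concurrent share_memory_() calls on the *same* storage
# serialize while calls on different storages proceed in parallel.
#
#     >>> import threading
#     >>> s = torch.UntypedStorage(64)
#     >>> threads = [threading.Thread(target=s.share_memory_) for _ in range(4)]
#     >>> for t in threads:
#     ...     t.start()
#     >>> for t in threads:
#     ...     t.join()
#     >>> s.is_shared()
#     True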
class UntypedStorage(torch._C.StorageBase, _StorageBase):
    def __getitem__(self, *args, **kwargs):
        if self.device.type == 'meta':
            raise NotImplementedError("Not available for 'meta' device type")
        return super().__getitem__(*args, **kwargs)

    @property
    def is_cuda(self):
        return self.device.type == 'cuda'

    @property
    def is_hpu(self):
        return self.device.type == 'hpu'

    @property
    def filename(self) -> _Optional[str]:
        """Returns the file name associated with this storage if the storage was memory mapped from a file,
        or ``None`` if the storage was not created by memory mapping a file."""
        return self._get_filename()

    @_share_memory_lock_protected
    def share_memory_(self, *args, **kwargs):
        """
        Moves the storage to shared memory.

        This is a no-op for storages already in shared memory and for CUDA
        storages, which do not need to be moved for sharing across processes.
        Storages in shared memory cannot be resized.

        Note that to mitigate issues like `this <https://github.com/pytorch/pytorch/issues/95606>`_,
        it is thread-safe to call this function from multiple threads on the same object.
        It is NOT thread-safe, however, to call any other function on self without proper
        synchronization. Please see :doc:`/notes/multiprocessing` for more details.

        When all references to a storage in shared memory are deleted, the associated shared memory
        object will also be deleted. PyTorch has a special cleanup process to ensure that this happens
        even if the current process exits unexpectedly.

        It is worth noting the difference between :meth:`share_memory_` and :meth:`from_file` with ``shared = True``:

        #. ``share_memory_`` uses `shm_open(3) <https://man7.org/linux/man-pages/man3/shm_open.3.html>`_ to create a
           POSIX shared memory object, while :meth:`from_file` uses
           `open(2) <https://man7.org/linux/man-pages/man2/open.2.html>`_ to open the filename passed by the user.
        #. Both use an `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_ with ``MAP_SHARED``
           to map the file/object into the current virtual address space.
        #. ``share_memory_`` will call ``shm_unlink(3)`` on the object after mapping it to make sure the shared memory
           object is freed when no process has the object open. ``torch.from_file(shared=True)`` does not unlink the
           file. This file is persistent and will remain until it is deleted by the user.
        """
        return super().share_memory_(*args, **kwargs)
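    # Illustrative sketch of the two sharing routes compared in the docstring
    # above; '/tmp/shared_buf' is a hypothetical path.
    #
    #     >>> s = torch.UntypedStorage(64).share_memory_()                    # anonymous POSIX shm, unlinked after mmap
    #     >>> f = torch.UntypedStorage.from_file('/tmp/shared_buf', True, 64)  # named file, persists on disk
    #     >>> f.filename
    #     '/tmp/shared_buf'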
    @_share_memory_lock_protected
    def _share_fd_cpu_(self, *args, **kwargs):
        return super()._share_fd_cpu_(*args, **kwargs)

    @_share_memory_lock_protected
    def _share_filename_cpu_(self, *args, **kwargs):
        return super()._share_filename_cpu_(*args, **kwargs)


def _load_from_bytes(b):
    return torch.load(io.BytesIO(b))


_StorageBase.type = _type
_StorageBase.cuda = _cuda
_StorageBase.hpu = _hpu


@lru_cache(maxsize=None)
def _dtype_to_storage_type_map():
    return {
        torch.double: 'DoubleStorage',
        torch.float: 'FloatStorage',
        torch.half: 'HalfStorage',
        torch.long: 'LongStorage',
        torch.int: 'IntStorage',
        torch.int16: 'ShortStorage',
        torch.int8: 'CharStorage',
        torch.uint8: 'ByteStorage',
        torch.bool: 'BoolStorage',
        torch.bfloat16: 'BFloat16Storage',
        torch.cdouble: 'ComplexDoubleStorage',
        torch.cfloat: 'ComplexFloatStorage',
        torch.qint8: 'QInt8Storage',
        torch.qint32: 'QInt32Storage',
        torch.quint8: 'QUInt8Storage',
        torch.quint4x2: 'QUInt4x2Storage',
        torch.quint2x4: 'QUInt2x4Storage',
    }


@lru_cache(maxsize=None)
def _storage_type_to_dtype_map():
    return {
        val: key for key, val in _dtype_to_storage_type_map().items()}
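# Illustrative sketch: the two memoized maps above translate between dtypes and
# legacy pickle storage-type names in both directions.
#
#     >>> _dtype_to_storage_type_map()[torch.float]
#     'FloatStorage'
#     >>> _storage_type_to_dtype_map()['FloatStorage']
#     torch.float32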
def _get_storage_from_sequence(sequence, dtype, device):
    if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
        interpret_dtypes = {
            torch.quint8: torch.uint8,
            torch.quint4x2: torch.uint8,
            torch.quint2x4: torch.uint8,
            torch.qint32: torch.int32,
            torch.qint8: torch.int8
        }
        tmp_tensor = torch.tensor(
            sequence,
            dtype=interpret_dtypes[dtype],
            device=device)
    else:
        tmp_tensor = torch.tensor(
            sequence,
            dtype=dtype,
            device=device)

    return tmp_tensor._typed_storage()._untyped_storage


def _isint(x):
    if HAS_NUMPY:
        return isinstance(x, (int, np.integer))
    else:
        return isinstance(x, int)


_always_warn_typed_storage_removal = False


def _get_always_warn_typed_storage_removal():
    return _always_warn_typed_storage_removal


def _set_always_warn_typed_storage_removal(always_warn):
    global _always_warn_typed_storage_removal
    assert isinstance(always_warn, bool)
    _always_warn_typed_storage_removal = always_warn


def _warn_typed_storage_removal(stacklevel=2):
    global _always_warn_typed_storage_removal

    def is_first_time():
        if not hasattr(_warn_typed_storage_removal, 'has_warned'):
            return True
        return not _warn_typed_storage_removal.__dict__['has_warned']

    if _get_always_warn_typed_storage_removal() or is_first_time():
        message = (
            "TypedStorage is deprecated. It will be removed in the future and "
            "UntypedStorage will be the only storage class. This should only matter "
            "to you if you are using storages directly. To access UntypedStorage "
            "directly, use tensor.untyped_storage() instead of tensor.storage()")
        warnings.warn(message, UserWarning, stacklevel=stacklevel + 1)
        _warn_typed_storage_removal.__dict__['has_warned'] = True


def _reset_warn_typed_storage_removal():
    _warn_typed_storage_removal.__dict__['has_warned'] = False


def _get_device_from_module(module: str):
    last_part = module.rsplit(".", 1)[-1]
    if last_part in ["cuda", torch._C._get_privateuse1_backend_name(), "hpu"]:
        return last_part
    return "cpu"
class TypedStorage:
    @property
    def filename(self) -> _Optional[str]:
        """Returns the file name associated with this storage if the storage was memory mapped from a file,
        or ``None`` if the storage was not created by memory mapping a file."""
        return self._untyped_storage.filename

    def fill_(self, value):
        _warn_typed_storage_removal()
        self._setitem(slice(0, self._size()), value)
        return self

    def __new__(cls, *args, wrap_storage=None, dtype=None, device=None, _internal=False):
        if not _internal:
            _warn_typed_storage_removal()

        if cls == torch.storage._LegacyStorage:
            raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")

        if cls == TypedStorage:
            return super().__new__(cls)

        arg_error_msg = (
            f'{cls}.__new__ received an invalid combination '
            f'of arguments. Expected one of:\n'
            ' * (Sequence data)\n'
            ' * (*, UntypedStorage wrap_storage)')

        if device is not None:
            raise RuntimeError(arg_error_msg +
                "\nKeyword argument 'device' cannot be specified")

        if dtype is not None:
            raise RuntimeError(arg_error_msg +
                "\nKeyword argument 'dtype' cannot be specified")

        if wrap_storage is None:
            if len(args) > 1:
                raise RuntimeError(arg_error_msg +
                    "\nToo many positional arguments")

            if len(args) == 1 and not _isint(args[0]) and not isinstance(args[0], collections.abc.Sequence):
                raise TypeError(arg_error_msg +
                    f"\nArgument type not recognized: {type(args[0])}")

            return TypedStorage(
                *args,
                device=_get_device_from_module(cls.__module__),
                dtype=cls.dtype,
                _internal=True)
        else:
            if len(args) != 0:
                raise RuntimeError(arg_error_msg +
                    "\nNo positional arguments should be given when using "
                    "'wrap_storage'")

            if not isinstance(wrap_storage, torch.UntypedStorage):
                raise TypeError(arg_error_msg +
                    f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")

            cls_device = _get_device_from_module(cls.__module__)
            if wrap_storage.device.type != cls_device:
                raise RuntimeError(arg_error_msg +
                    f"\nDevice of 'wrap_storage' must be {cls_device}"
                    f", but got {wrap_storage.device.type}")

            return TypedStorage(
                wrap_storage=wrap_storage,
                dtype=cls.dtype,
                _internal=True)
    def __init__(self, *args, device=None, dtype=None, wrap_storage=None, _internal=False):
        if not _internal:
            _warn_typed_storage_removal()
        arg_error_msg = (
            'TypedStorage.__init__ received an invalid combination '
            'of arguments. Expected one of:\n'
            ' * (*, torch.device device, torch.dtype dtype)\n'
            ' * (int size, *, torch.device device, torch.dtype dtype)\n'
            ' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
            ' * (*, UntypedStorage wrap_storage, torch.dtype dtype)')

        if wrap_storage is not None:
            if len(args) != 0:
                raise RuntimeError(arg_error_msg +
                    "\nNo positional arguments should be given when using "
                    "'wrap_storage'")

            if dtype is None:
                raise RuntimeError(arg_error_msg +
                    "\nArgument 'dtype' must be specified")

            if not isinstance(dtype, torch.dtype):
                raise TypeError(arg_error_msg +
                    f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}")

            if device is not None:
                raise RuntimeError(arg_error_msg +
                    "\nArgument 'device' should not be specified when 'wrap_storage' is given")

            self.dtype = dtype

            if not isinstance(wrap_storage, torch.UntypedStorage):
                raise TypeError(arg_error_msg +
                    f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")

            self._untyped_storage = wrap_storage
        else:
            self.dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device('cpu' if device is None else device)

            if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
                if device.type == 'cuda':
                    raise RuntimeError("Cannot create CUDA storage with quantized dtype")

            if len(args) == 0:
                self._untyped_storage = torch.UntypedStorage(device=device)
            elif len(args) == 1:
                if _isint(args[0]):
                    self._untyped_storage = torch.UntypedStorage(int(args[0]) * self._element_size(), device=device)
                elif isinstance(args[0], collections.abc.Sequence):
                    self._untyped_storage = _get_storage_from_sequence(args[0], self.dtype, device)
                else:
                    raise TypeError(arg_error_msg +
                        f"\nArgument type not recognized: {type(args[0])}")
            else:
                raise RuntimeError(arg_error_msg +
                    "\nToo many positional arguments")

    @property
    def is_cuda(self):
        _warn_typed_storage_removal()
        return self._untyped_storage.device.type == 'cuda'

    @property
    def is_hpu(self):
        _warn_typed_storage_removal()
        return self._untyped_storage.device.type == 'hpu'

    def untyped(self):
        """Return the internal :class:`torch.UntypedStorage`."""
        _warn_typed_storage_removal()
        return self._untyped_storage

    def _new_wrapped_storage(self, untyped_storage):
        assert type(untyped_storage) == torch.UntypedStorage

        if type(self) == TypedStorage:
            return TypedStorage(
                wrap_storage=untyped_storage,
                dtype=self.dtype,
                _internal=True)
        return type(self)(wrap_storage=untyped_storage)
    def __len__(self):
        _warn_typed_storage_removal()
        return self._size()

    def _maybe_wrap_index(self, idx, is_stop=False):
        if type(idx) != int:
            raise TypeError(
                f"can't index a {type(self)} with {type(idx)}")
        if is_stop:
            if (idx > self._size()) or (idx < -self._size()):
                raise IndexError(
                    f'index {idx} out of range for storage of size {self.size()}')
            if idx > 0:
                return idx
            return idx % self._size()
        else:
            if (idx >= self._size()) or (idx < -self._size()):
                raise IndexError(
                    f'index {idx} out of range for storage of size {self.size()}')
            return idx % self._size()

    def __setitem__(self, idx, value):
        _warn_typed_storage_removal()
        return self._setitem(idx, value)

    def _setitem(self, idx, value):
        if not isinstance(idx, (int, slice)):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
        if torch.is_storage(value):
            raise RuntimeError(f'cannot set item with value type {type(value)}')
        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            interpret_dtypes = {
                torch.quint8: torch.uint8,
                torch.quint4x2: torch.uint8,
                torch.quint2x4: torch.uint8,
                torch.qint32: torch.int32,
                torch.qint8: torch.int8
            }
            tmp_dtype = interpret_dtypes[self.dtype]
            tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self._untyped_storage.device)
            tmp_tensor.set_(TypedStorage(
                wrap_storage=self._untyped_storage,
                dtype=tmp_dtype,
                _internal=True))
        else:
            tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)

        tmp_tensor[idx] = value

    def __getitem__(self, idx):
        _warn_typed_storage_removal()
        return self._getitem(idx)

    def _getitem(self, idx):
        if self._untyped_storage.device.type == 'meta':
            raise NotImplementedError("Not available for 'meta' device type")

        if isinstance(idx, slice):
            raise RuntimeError('slices are only supported in UntypedStorage.__getitem__')
        elif not isinstance(idx, int):
            raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")

        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            interpret_dtypes = {
                torch.quint8: torch.uint8,
                torch.quint4x2: torch.uint8,
                torch.quint2x4: torch.uint8,
                torch.qint32: torch.int32,
                torch.qint8: torch.int8
            }
            return TypedStorage(
                wrap_storage=self._untyped_storage,
                dtype=interpret_dtypes[self.dtype],
                _internal=True)._getitem(idx)

        idx_wrapped = self._maybe_wrap_index(idx)
        from torch._subclasses.fake_tensor import unset_fake_temporarily

        with unset_fake_temporarily():
            tmp_tensor = torch.tensor([], dtype=self.dtype, device=self._untyped_storage.device).set_(self)
            return tmp_tensor[idx_wrapped].item()
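    # Illustrative sketch: TypedStorage indexing above goes through a temporary
    # tensor, returns a Python number, and rejects slices (slices are only
    # supported on UntypedStorage).
    #
    #     >>> s = torch.FloatStorage([1.0, 2.0, 3.0])   # legacy typed storage (deprecated)
    #     >>> s[1]
    #     2.0
    #     >>> s[0:2]                                     # raises RuntimeError
    #     RuntimeError: slices are only supported in UntypedStorage.__getitem__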
    def copy_(self, source: T, non_blocking: _Optional[bool] = None):
        _warn_typed_storage_removal()
        if isinstance(source, TypedStorage):
            self._untyped_storage.copy_(source._untyped_storage, non_blocking)
        else:
            self._untyped_storage.copy_(source, non_blocking)
        return self

    def nbytes(self):
        _warn_typed_storage_removal()
        return self._nbytes()

    # For internal use only, to avoid the deprecation warning
    def _nbytes(self):
        return self._untyped_storage.nbytes()

    def type(self, dtype: _Optional[str] = None, non_blocking: bool = False) -> Union[T, str]:
        _warn_typed_storage_removal()
        if dtype is None:
            legacy_class = self._get_legacy_storage_class()

            if legacy_class is not None:
                return legacy_class.__module__ + '.' + legacy_class.__name__

            return '.'.join([self.__module__, type(self).__name__])
        else:
            return self._untyped_storage.type(dtype, non_blocking)

    def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
        _warn_typed_storage_removal()
        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            raise RuntimeError("Cannot create CUDA storage with quantized dtype")
        cuda_storage: torch.UntypedStorage = self._untyped_storage.cuda(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(cuda_storage)

    def hpu(self, device=None, non_blocking=False, **kwargs) -> T:
        _warn_typed_storage_removal()
        if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
            raise RuntimeError("Cannot create HPU storage with quantized dtype")
        hpu_storage: torch.UntypedStorage = self._untyped_storage.hpu(device, non_blocking, **kwargs)
        return self._new_wrapped_storage(hpu_storage)

    def element_size(self):
        _warn_typed_storage_removal()
        return self._element_size()

    # For internal use only, to avoid the deprecation warning
    def _element_size(self):
        return torch._utils._element_size(self.dtype)

    def get_device(self) -> int:
        _warn_typed_storage_removal()
        return self._untyped_storage.get_device()

    def __str__(self):
        _warn_typed_storage_removal()
        info_str = (
            f'[{torch.typename(self)}(dtype={self.dtype}, '
            f'device={self.device}) of size {len(self)}]')
        if self.device.type == 'meta':
            return '...\n' + info_str
        data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
        return data_str + '\n' + info_str
    def __repr__(self):
        _warn_typed_storage_removal()
        return str(self)

    def __iter__(self):
        _warn_typed_storage_removal()
        return iter(self[i] for i in range(self.size()))

    def __copy__(self):
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(copy.copy(self._untyped_storage))

    def __deepcopy__(self, memo):
        _warn_typed_storage_removal()
        return self._deepcopy(memo)

    # For internal use only, to avoid the deprecation warning
    def _deepcopy(self, memo):
        return self._new_wrapped_storage(copy.deepcopy(self._untyped_storage, memo))

    def __sizeof__(self):
        _warn_typed_storage_removal()
        return super().__sizeof__() + self.nbytes()

    def clone(self):
        """Return a copy of this storage."""
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.clone())

    def tolist(self):
        """Return a list containing the elements of this storage."""
        _warn_typed_storage_removal()
        return list(self)

    def cpu(self):
        """Return a CPU copy of this storage if it's not already on the CPU."""
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.cpu())

    def is_pinned(self, device: Union[str, torch.device] = 'cuda'):
        r"""Determine whether the CPU TypedStorage is already pinned on device.

        Args:
            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.
        """
        _warn_typed_storage_removal()
        return self._untyped_storage.is_pinned(device)

    def pin_memory(self, device: Union[str, torch.device] = 'cuda'):
        r"""Copy the CPU TypedStorage to pinned memory, if it's not already pinned.

        Args:
            device (str or torch.device): The device to pin memory on. Default: ``'cuda'``.

        Returns:
            A pinned CPU storage.
        """
        _warn_typed_storage_removal()
        return self._new_wrapped_storage(self._untyped_storage.pin_memory(device=device))

    def share_memory_(self):
        """See :meth:`torch.UntypedStorage.share_memory_`"""
        _warn_typed_storage_removal()
        return self._share_memory_()

    # For internal use only, to avoid the deprecation warning
    def _share_memory_(self):
        self._untyped_storage.share_memory_()
        return self
    def _new_shared(self, size, *, device=None):
        """Create a new storage in shared memory with the same data type."""
        device = 'cpu' if device is None else device
        device = torch.device(device)
        untyped_storage = torch.UntypedStorage._new_shared(size * self._element_size(), device=device)
        return TypedStorage(
            wrap_storage=untyped_storage,
            dtype=self.dtype,
            _internal=True)

    @property
    def _cdata(self):
        return self._untyped_storage._cdata

    @property
    def device(self):
        _warn_typed_storage_removal()
        return self._untyped_storage.device

    def size(self):
        _warn_typed_storage_removal()
        return self._size()

    def _size(self):
        return self._untyped_storage.nbytes() // self._element_size()

    def pickle_storage_type(self):
        _warn_typed_storage_removal()
        return self._pickle_storage_type()

    def _pickle_storage_type(self):
        try:
            return _dtype_to_storage_type_map()[self.dtype]
        except KeyError as e:
            raise KeyError(f'dtype {self.dtype} is not recognized') from e

    def __reduce__(self):
        b = io.BytesIO()
        torch.save(self, b, _use_new_zipfile_serialization=False)
        return (_load_from_bytes, (b.getvalue(),))

    def data_ptr(self):
        _warn_typed_storage_removal()
        return self._data_ptr()

    def _data_ptr(self):
        return self._untyped_storage.data_ptr()

    def resizable(self):
        _warn_typed_storage_removal()
        return self._untyped_storage.resizable()
    def resize_(self, size):
        _warn_typed_storage_removal()
        self._resize_(size)

    # For internal use only, to avoid the deprecation warning
    def _resize_(self, size):
        self._untyped_storage.resize_(size * self._element_size())

    @classmethod
    def _free_weak_ref(cls, *args, **kwargs):
        return UntypedStorage._free_weak_ref(*args, **kwargs)

    def _weak_ref(self, *args, **kwargs):
        return self._untyped_storage._weak_ref(*args, **kwargs)

    @classmethod
    def from_buffer(cls, *args, **kwargs):
        _warn_typed_storage_removal()
        return cls._from_buffer(*args, **kwargs)

    @classmethod
    def _from_buffer(cls, *args, dtype=None, device=None, **kwargs):
        if cls == TypedStorage:
            dtype = torch.get_default_dtype() if dtype is None else dtype
            device = torch.device('cpu' if device is None else device)
            if device.type != 'cpu':
                raise RuntimeError(f'TypedStorage.from_buffer: Not available for device {device.type}')
            untyped_storage: torch.UntypedStorage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
        else:
            if dtype is not None or len(args) == 5:
                raise RuntimeError(
                    "from_buffer: 'dtype' can only be specified in "
                    "UntypedStorage.from_buffer and TypedStorage.from_buffer")
            if device is not None:
                raise RuntimeError(
                    "from_buffer: 'device' can only be specified in "
                    "UntypedStorage.from_buffer and TypedStorage.from_buffer")

            dtype = cls.dtype
            untyped_storage = torch.UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)

        return TypedStorage(
            wrap_storage=untyped_storage,
            dtype=dtype,
            _internal=True)

    def _to(self, dtype):
        if not isinstance(dtype, torch.dtype):
            raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
        storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype)._typed_storage()
        if storage.data_ptr() == self.data_ptr():
            storage = storage.clone()
        return storage
    def double(self):
        """Casts this storage to double type."""
        _warn_typed_storage_removal()
        return self._to(torch.double)

    def float(self):
        """Casts this storage to float type."""
        _warn_typed_storage_removal()
        return self._to(torch.float)

    def half(self):
        """Casts this storage to half type."""
        _warn_typed_storage_removal()
        return self._to(torch.half)

    def long(self):
        """Casts this storage to long type."""
        _warn_typed_storage_removal()
        return self._to(torch.long)

    def int(self):
        """Casts this storage to int type."""
        _warn_typed_storage_removal()
        return self._to(torch.int)

    def short(self):
        """Casts this storage to short type."""
        _warn_typed_storage_removal()
        return self._to(torch.short)

    def char(self):
        """Casts this storage to char type."""
        _warn_typed_storage_removal()
        return self._to(torch.int8)

    def byte(self):
        """Casts this storage to byte type."""
        _warn_typed_storage_removal()
        return self._to(torch.uint8)

    def bool(self):
        """Casts this storage to bool type."""
        _warn_typed_storage_removal()
        return self._to(torch.bool)

    def bfloat16(self):
        """Casts this storage to bfloat16 type."""
        _warn_typed_storage_removal()
        return self._to(torch.bfloat16)

    def complex_double(self):
        """Casts this storage to complex double type."""
        _warn_typed_storage_removal()
        return self._to(torch.cdouble)

    def complex_float(self):
        """Casts this storage to complex float type."""
        _warn_typed_storage_removal()
        return self._to(torch.cfloat)

    def float8_e5m2(self):
        """Casts this storage to float8_e5m2 type."""
        _warn_typed_storage_removal()
        return self._to(torch.float8_e5m2)

    def float8_e4m3fn(self):
        """Casts this storage to float8_e4m3fn type."""
        _warn_typed_storage_removal()
        return self._to(torch.float8_e4m3fn)

    def float8_e5m2fnuz(self):
        """Casts this storage to float8_e5m2fnuz type."""
        _warn_typed_storage_removal()
        return self._to(torch.float8_e5m2fnuz)

    def float8_e4m3fnuz(self):
        """Casts this storage to float8_e4m3fnuz type."""
        _warn_typed_storage_removal()
        return self._to(torch.float8_e4m3fnuz)
    @classmethod
    def from_file(cls, filename, shared, size):
        """from_file(filename, shared=False, size=0) -> Storage

        Creates a CPU storage backed by a memory-mapped file.

        If ``shared`` is ``True``, then memory is shared between all processes.
        All changes are written to the file. If ``shared`` is ``False``, then changes to
        the storage do not affect the file.

        ``size`` is the number of elements in the storage. If ``shared`` is ``False``,
        then the file must contain at least ``size * sizeof(Type)`` bytes
        (``Type`` is the type of storage). If ``shared`` is ``True``, the file will be created if needed.

        Args:
            filename (str): file name to map
            shared (bool): whether to share memory (whether ``MAP_SHARED`` or ``MAP_PRIVATE`` is passed to the
                           underlying `mmap(2) call <https://man7.org/linux/man-pages/man2/mmap.2.html>`_)
            size (int): number of elements in the storage
        """
        _warn_typed_storage_removal()
        if cls == TypedStorage:
            raise RuntimeError('from_file can only be called on derived classes')
        untyped_storage: UntypedStorage = UntypedStorage.from_file(
            filename,
            shared,
            size * torch._utils._element_size(cls.dtype))
        storage = cls(wrap_storage=untyped_storage)
        return storage
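    # Illustrative sketch: memory-mapping a file as a typed storage;
    # '/tmp/weights.bin' is a hypothetical path.
    #
    #     >>> s = torch.FloatStorage.from_file('/tmp/weights.bin', shared=True, size=10)
    #     >>> s.fill_(0.0)          # writes zeros through to the mapped file
    #     >>> s.nbytes()            # 10 elements * 4 bytes each
    #     40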
    @classmethod
    def _expired(cls, *args, **kwargs):
        return UntypedStorage._expired(*args, **kwargs)

    def _write_file(self, *args, **kwargs):
        return self._untyped_storage._write_file(*args, **kwargs)

    def _set_from_file(self, *args, **kwargs):
        return self._untyped_storage._set_from_file(*args, **kwargs)

    def _set_cdata(self, *args, **kwargs):
        return self._untyped_storage._set_cdata(*args, **kwargs)

    def _share_cuda_(self, *args, **kwargs):
        return self._untyped_storage._share_cuda_(*args, **kwargs)

    def is_shared(self):
        _warn_typed_storage_removal()
        return self._is_shared()

    # For internal use only, to avoid the deprecation warning
    def _is_shared(self):
        return self._untyped_storage.is_shared()

    @classmethod
    def _new_shared_cuda(cls, *args, **kwargs):
        return torch.UntypedStorage._new_shared_cuda(*args, **kwargs)

    def _share_filename_cpu_(self, *args, **kwargs):
        manager_handle, storage_handle, size = self._untyped_storage._share_filename_cpu_(*args, **kwargs)
        return manager_handle, storage_handle, size // self._element_size()

    def _shared_decref(self):
        self._untyped_storage._shared_decref()
        return self

    @classmethod
    def _release_ipc_counter(cls, *args, device=None, **kwargs):
        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    def _shared_incref(self, *args, **kwargs):
        return self._untyped_storage._shared_incref(*args, **kwargs)

    def _share_fd_cpu_(self, *args, **kwargs):
        fd, size = self._untyped_storage._share_fd_cpu_(*args, **kwargs)
        return fd, size // self._element_size()

    def _get_legacy_storage_class(self):
        if self.dtype not in _dtype_to_storage_type_map():
            return None

        storage_name = _dtype_to_storage_type_map()[self.dtype]

        if self.device.type not in ['cpu', 'cuda', 'hpu', torch._C._get_privateuse1_backend_name()]:
            return None

        module = torch if self.device.type == 'cpu' else getattr(torch, self.device.type)

        try:
            return getattr(module, storage_name)
        except AttributeError:
            return None


TypedStorage.type.__doc__ = _type.__doc__
TypedStorage.cuda.__doc__ = _cuda.__doc__
TypedStorage.hpu.__doc__ = _hpu.__doc__


class _LegacyStorageMeta(type):
    def __instancecheck__(cls, instance):
        if type(instance) == TypedStorage:
            cls_device = _get_device_from_module(cls.__module__)
            return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
        return False
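    # Illustrative sketch: the metaclass above makes isinstance() match a plain
    # TypedStorage against the legacy class whose device and dtype agree.
    #
    #     >>> t = torch.ones(4)
    #     >>> isinstance(t.storage(), torch.FloatStorage)    # cpu + float32 matches
    #     True
    #     >>> isinstance(t.storage(), torch.DoubleStorage)
    #     False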
class _LegacyStorage(TypedStorage, metaclass=_LegacyStorageMeta):
    @classmethod
    def _new_shared(cls, size):
        """Create a new storage in shared memory with the same data type."""
        untyped_storage = torch.UntypedStorage._new_shared(size * cls()._element_size())
        return cls(wrap_storage=untyped_storage)

    @classmethod
    def _release_ipc_counter(cls, *args, **kwargs):
        return torch.UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)

    @classmethod
    def _new_shared_filename(cls, manager, obj, size):
        bytes_size = size * torch._utils._element_size(cls.dtype)
        return cls(wrap_storage=torch.UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))


def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
    try:
        return _storage_type_to_dtype_map()[pickle_storage_type]
    except KeyError as e:
        raise KeyError(
            f'pickle storage type "{pickle_storage_type}" is not recognized') from e