pytorch

Форк
0
/
__init__.py 
130 строк · 4.2 Кб
1
r"""
2
This package provides an interface for accessing the MPS (Metal Performance Shaders) backend in Python.
Metal is Apple's API for programming Metal GPUs (graphics processing units). Using MPS can
increase performance by running work on the Metal GPU(s).
See https://developer.apple.com/documentation/metalperformanceshaders for more details.
6
"""
7
import torch
8
from .. import Tensor
9

10
# Callable taken from ``torch._C._mps_is_in_bad_fork`` when that binding exists;
# otherwise a stub that always returns False. (The name suggests it detects an
# unsafe post-fork state — confirm against the C++ binding.)
_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)

# Cache for the default MPS generator. Starts as None and is populated lazily by
# ``_get_default_mps_generator()``; the ignore silences the None-vs-Generator
# annotation mismatch.
_default_mps_generator: torch._C.Generator = None  # type: ignore[assignment]
12

13

14
# local helper function (not public or exported)
15
def _get_default_mps_generator() -> torch._C.Generator:
16
    global _default_mps_generator
17
    if _default_mps_generator is None:
18
        _default_mps_generator = torch._C._mps_get_default_generator()
19
    return _default_mps_generator
20

21

22
def synchronize() -> None:
    r"""Wait for all kernels in all streams on an MPS device to complete.

    Blocks the caller until the queued device work has finished.
    """
    device_sync = torch._C._mps_deviceSynchronize
    return device_sync()
25

26

27
def get_rng_state() -> Tensor:
    r"""Return the state of the default MPS random number generator.

    Returns:
        torch.ByteTensor: the generator state, suitable for ``set_rng_state``.
    """
    generator = _get_default_mps_generator()
    return generator.get_state()
30

31

32
def set_rng_state(new_state: Tensor) -> None:
    r"""Restore the default MPS random number generator to a given state.

    Args:
        new_state (torch.ByteTensor): The desired state, e.g. a value previously
            returned by ``get_rng_state()``.
    """
    # The generator receives a contiguous clone rather than the caller's tensor
    # itself (presumably because set_state expects contiguous memory — confirm).
    contiguous_copy = new_state.clone(memory_format=torch.contiguous_format)
    generator = _get_default_mps_generator()
    generator.set_state(contiguous_copy)
40

41

42
def manual_seed(seed: int) -> None:
43
    r"""Sets the seed for generating random numbers.
44

45
    Args:
46
        seed (int): The desired seed.
47
    """
48
    # the torch.mps.manual_seed() can be called from the global
49
    # torch.manual_seed() in torch/random.py. So we need to make
50
    # sure mps is available (otherwise we just return without
51
    # erroring out)
52
    if not torch._C._has_mps:
53
        return
54
    seed = int(seed)
55
    _get_default_mps_generator().manual_seed(seed)
56

57

58
def seed() -> None:
    r"""Reseed the default MPS random number generator with a random value.

    Unlike ``manual_seed``, this does not check ``torch._C._has_mps`` first.
    """
    generator = _get_default_mps_generator()
    generator.seed()
61

62

63
def empty_cache() -> None:
    r"""Release all unoccupied memory currently cached by the MPS caching allocator.

    The released memory can then be used by other GPU applications.
    """
    release_cached = torch._C._mps_emptyCache
    release_cached()
68

69

70
def set_per_process_memory_fraction(fraction) -> None:
    r"""Set memory fraction for limiting the process's memory allocation on the MPS device.

    The allowed value equals the fraction multiplied by the recommended maximum device
    memory (obtained from Metal API device.recommendedMaxWorkingSetSize).
    If the process tries to allocate more than the allowed value, the allocator raises
    an out-of-memory error.

    Args:
        fraction (float): Range: 0~2. Allowed memory equals total_memory * fraction.
            Integers such as ``0`` or ``1`` are accepted and converted to ``float``;
            ``bool`` is rejected.

    Raises:
        TypeError: If ``fraction`` is not a ``float`` or ``int`` (or is a ``bool``).
        ValueError: If ``fraction`` is NaN or outside the range ``[0, 2]``.

    .. note::
       Passing 0 to fraction means unlimited allocations
       (may cause system failure if out of memory).
       Passing fraction greater than 1.0 allows limits beyond the value
       returned from device.recommendedMaxWorkingSetSize.
    """
    # ``bool`` subclasses ``int``; reject it so ``True`` is not silently read as 1.0.
    if isinstance(fraction, bool) or not isinstance(fraction, (float, int)):
        raise TypeError("Invalid type for fraction argument, must be `float` or `int`")
    # A single chained comparison rejects NaN as well: every comparison with NaN
    # is False, so the old ``fraction < 0 or fraction > 2`` test let it through.
    # The check runs before float() so huge ints fail here, not with OverflowError.
    if not 0 <= fraction <= 2:
        raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2")

    torch._C._mps_setMemoryFraction(float(fraction))
93

94

95
def current_allocated_memory() -> int:
    r"""Return the GPU memory currently occupied by tensors, in bytes.

    .. note::
       Cached allocations held in the MPSAllocator memory pools are not
       counted; ``driver_allocated_memory()`` reports a figure that includes them.
    """
    allocated_bytes = torch._C._mps_currentAllocatedMemory()
    return allocated_bytes
103

104

105
def driver_allocated_memory() -> int:
    r"""Return the total GPU memory the Metal driver has allocated for this process, in bytes.

    .. note::
       This figure counts the cached allocations held in MPSAllocator pools,
       plus allocations made by the MPS/MPSGraph frameworks.
    """
    driver_bytes = torch._C._mps_driverAllocatedMemory()
    return driver_bytes
113

114

115
from . import profiler
116
from .event import Event
117

118
# Public API of this package: the functions defined in this module plus the
# ``profiler`` submodule and ``Event`` class imported just before this list.
__all__ = [
    # random number generation
    "get_rng_state",
    "manual_seed",
    "seed",
    "set_rng_state",
    # device synchronization and memory management
    "synchronize",
    "empty_cache",
    "set_per_process_memory_fraction",
    "current_allocated_memory",
    "driver_allocated_memory",
    # re-exported from submodules
    "Event",
    "profiler",
]
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.