pytorch

Форк
0
/
__init__.py 
68 строк · 2.3 Кб
1
import os.path as _osp
2
import torch
3

4
from .throughput_benchmark import ThroughputBenchmark
5
from .cpp_backtrace import get_cpp_backtrace
6
from .backend_registration import rename_privateuse1_backend, generate_methods_for_privateuse1_backend
7
from . import deterministic
8
from . import collect_env
9
import weakref
10
import copyreg
11

12
def set_module(obj, mod):
13
    """
14
    Set the module attribute on a python object for a given object for nicer printing
15
    """
16
    if not isinstance(mod, str):
17
        raise TypeError("The mod argument should be a string")
18
    obj.__module__ = mod
19

20
if torch._running_with_deploy():
21
    # not valid inside torch_deploy interpreter, no paths exists for frozen modules
22
    cmake_prefix_path = None
23
else:
24
    cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), 'share', 'cmake')
25

26
def swap_tensors(t1, t2):
27
    """
28
    This function swaps the content of the two Tensor objects.
29
    At a high level, this will make t1 have the content of t2 while preserving
30
    its identity.
31

32
    This will not work if t1 and t2 have different slots.
33
    """
34
    # Ensure there are no weakrefs
35
    if weakref.getweakrefs(t1):
36
        raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
37
    if weakref.getweakrefs(t2):
38
        raise RuntimeError("Cannot swap t2 because it has weakref associated with it")
39
    t1_slots = set(copyreg._slotnames(t1.__class__))  # type: ignore[attr-defined]
40
    t2_slots = set(copyreg._slotnames(t2.__class__))  # type: ignore[attr-defined]
41
    if t1_slots != t2_slots:
42
        raise RuntimeError("Cannot swap t1 and t2 if they have different slots")
43

44
    def swap_attr(name):
45
        tmp = getattr(t1, name)
46
        setattr(t1, name, (getattr(t2, name)))
47
        setattr(t2, name, tmp)
48

49
    # Swap the types
50
    # Note that this will fail if there are mismatched slots
51
    swap_attr("__class__")
52

53
    # Swap the dynamic attributes
54
    swap_attr("__dict__")
55

56
    # Swap the slots
57
    for slot in t1_slots:
58
        if hasattr(t1, slot) and hasattr(t2, slot):
59
            swap_attr(slot)
60
        elif hasattr(t1, slot):
61
            setattr(t2, slot, (getattr(t1, slot)))
62
            delattr(t1, slot)
63
        elif hasattr(t2, slot):
64
            setattr(t1, slot, (getattr(t2, slot)))
65
            delattr(t2, slot)
66

67
    # Swap the at::Tensor they point to
68
    torch._C._swap_tensor_impl(t1, t2)
69

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.