# context.py
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence

import torch

import torch._decomp
import torch._prims

import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides

from torch._prims_common import torch_function_passthrough


@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add
    """
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.__rsub__: torch._refs.rsub,
        torch.Tensor.__rtruediv__: torch._refs.rtruediv,
        torch.Tensor.__floordiv__: torch._refs.floor_divide,
        torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
        torch.Tensor.__pow__: torch._refs.pow,
        torch.Tensor.__rpow__: torch._refs.rpow,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        torch.Tensor.to: torch._refs.to,
        torch.Tensor.sum_to_size: torch._refs.sum_to_size,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }
    for mod_torch, mod_refs in modules:
        for s in mod_refs.__all__:  # type: ignore[attr-defined]
            r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)

    # Support remapping torch.Tensor.foo to _refs.foo
    for s in dir(torch.Tensor):
        if s in torch._refs.__all__:
            r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)

    # Support conversions
    for s in torch._refs._conversions.__all__:
        tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s)
        r[tensor_attr] = torch._refs._conversions.__dict__.get(s)

    return r
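
# A few illustrative lookups (kept as comments so importing this module does
# not eagerly build the cached map; the entries follow from the construction
# above):
#
#   torch_to_refs_map()[torch.add] is torch._refs.add
#   torch_to_refs_map()[torch.Tensor.__pow__] is torch._refs.pow
#   torch_to_refs_map()[torch.nn.functional.relu] is torch._refs.nn.functional.relu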


@functools.lru_cache(None)
def all_prims():
    """
    Set of all prim functions, e.g., torch._prims.add in all_prims()
    """
    return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}
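
# For example (illustrative, not executed at import time):
#
#   torch._prims.add in all_prims()   # True: a primitive operation
#   torch.add in all_prims()          # False: the public API, not a prim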


class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs.  (Direct calls to _refs are unaffected.)

    >>> # xdoctest: +SKIP
    >>> with TorchRefsMode():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager falls back to the torch.* implementation
    if the ref does not exist; set strict=True to raise an error instead.
    Even when a ref exists, we may still want to fall back to torch.* in some
    cases; this behavior can be customized by passing a function as
    should_fallback_fn.
    """

    def __init__(
        self,
        strict=False,
        should_fallback_fn=lambda *_: False,
        prims_mode_cls=nullcontext,
    ):
        self.strict = strict
        self.should_fallback_fn = should_fallback_fn
        self.prims_mode_cls = prims_mode_cls

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}
        # For primitive operations, run them as is without interception,
        # unless we are in prims_mode, in which case we want to use nvprims
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            with self.prims_mode_cls():
                return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)

        # For torch.ops.aten.*, use registered decompositions from torch._decomp
        # torch._decomp.decomposition_table provides a mapping from
        # torch.ops.aten.* to torch._refs or torch._decomp.decompositions
        # implementations.
        # There are other ways to implement this functionality,
        # see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
        if func is None and isinstance(orig_func, torch._ops.OpOverload):
            func = torch._decomp.decomposition_table.get(orig_func, None)

        if func is not None:
            # If the ref exists, query whether we should use it or not
            if self.should_fallback_fn(self, orig_func, func, args, kwargs):
                return orig_func(*args, **kwargs)
            # torch calls inside func should be interpreted as refs calls
            with self:
                return func(*args, **kwargs)
        if self.strict:
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        return orig_func(*args, **kwargs)
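

# A minimal, self-contained usage sketch (not part of the original module):
# inside the mode, public torch.* calls are rerouted to their torch._refs
# counterparts, which in turn decompose into torch._prims calls.  The tensors
# below are illustrative inputs, not anything this module defines.
if __name__ == "__main__":
    x = torch.randn(3)
    y = torch.randn(3)

    with TorchRefsMode():
        out = torch.add(x, y)  # dispatched to torch._refs.add

    # The ref computes the same result as the regular eager implementation.
    assert torch.allclose(out, x + y)

    # With strict=True, an op that has no _refs implementation raises a
    # RuntimeError instead of silently falling back to torch.*.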