pyfunctorch.py 
# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from typing import Any, List, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    CFunctionalizeInterpreterPtr,
    CGradInterpreterPtr,
    CInterpreter,
    CJvpInterpreterPtr,
    CVmapInterpreterPtr,
    pop_dynamic_layer_stack,
    push_dynamic_layer_stack,
    RandomnessType,
    TransformType,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled


"""
This file contains the functorch integration with PyDispatcher.

PyDispatcher does not understand functorch's DynamicLayerStack dispatching
logic because it is entirely implemented in C++ in the fallbacks for two
dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
to directly reuse C++ boxed fallbacks).

Instead of trying to hammer PyDispatcher into understanding those fallbacks,
we re-implement the logic of peeking at the top of the stack for an
interpreter, selecting the interpreter to dispatch on, etc., in Python.
This leads to a simpler design.

The main difference between C++ functorch and PyDispatcher's functorch logic
is:
- C++ functorch needs to manually tweak dispatch keys to ping-pong between
  DynamicLayerFrontMode and DynamicLayerBackMode.
- PyDispatcher's functorch logic pops an Interpreter from the top of the stack
  and asks it to execute the rule associated with the Interpreter.

In C++ we do the ping-pong because e.g. vmap rules are associated with the
batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
the user to register a batching rule directly to a transform that an
interpreter then invokes.
"""

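# A rough sketch of the registration described in the docstring above. The
# operator `my_cond` and its rule are hypothetical, and this assumes that
# `py_impl` accepts a TransformType, which is how entries are expected to land
# in `op.functorch_table` (consulted by the interpreters below):
#
#     from torch._ops import HigherOrderOperator
#
#     my_cond = HigherOrderOperator("my_cond")
#
#     @my_cond.py_impl(TransformType.Vmap)
#     def my_cond_vmap_rule(interpreter, *args, **kwargs):
#         # The interpreter popped off the stack is passed in as the first
#         # argument; see VmapInterpreter.process below.
#         ...
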
# FuncTorchInterpreter is the Python version of Interpreter (recall that
# the DynamicLayerStack is a stack of interpreters).
# It is a wrapper around the actual C++ Interpreter object.
#
# Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
class FuncTorchInterpreter(ABC):
    def __init__(self, cptr: Any):
        self._cptr = cptr

    # Process an operation. E.g., for vmap, this means invoking a batching rule.
    # Conceptually this is analogous to Interpreter::process in C++
    @abstractmethod
    def process(self, op, args, kwargs):
        pass

    # Lower an operation from this Interpreter to the next Interpreter on the stack.
    # Concretely, this involves temporarily popping the current Interpreter.
    # Conceptually this is analogous to Interpreter::sendToNextInterpreter in C++
    def lower(self):
        return temporarily_pop_interpreter_stack()

    def level(self):
        return self._cptr.level()

    def key(self):
        return self._cptr.key()

    def get_state(self):
        raise NotImplementedError

    def check_state(self, state):
        return state == self.get_state()


@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
    try:
        saved = pop_dynamic_layer_stack()
        yield
    finally:
        push_dynamic_layer_stack(saved)


@contextlib.contextmanager
def temporarily_clear_interpreter_stack():
    stack = []
    try:
        while torch._C._functorch.peek_interpreter_stack() is not None:
            stack.append(pop_dynamic_layer_stack())
        yield list(stack)
    finally:
        while stack:
            push_dynamic_layer_stack(stack.pop())


@contextlib.contextmanager
def temporarily_restore_interpreter_stack(stack):
    pushed = []
    try:
        for s in reversed(stack):
            push_dynamic_layer_stack(s)
            pushed.append(s)
        yield
    finally:
        for s in reversed(pushed):
            # TODO: would be nice to assert that the layers are the same, but
            # Python object identity is not preserved
            pop_dynamic_layer_stack()

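# A usage sketch for the two helpers above (the call site is hypothetical;
# nothing in this file composes them this way). temporarily_clear_interpreter_stack
# yields the popped interpreters so they can be replayed inside the cleared
# region, e.g. via temporarily_restore_interpreter_stack:
#
#     with temporarily_clear_interpreter_stack() as saved:
#         # No functorch transforms are active here.
#         ...
#         with temporarily_restore_interpreter_stack(saved):
#             # The original transforms are active again inside this block.
#             ...
#     # On exit, the original interpreter stack is pushed back automatically.
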
class VmapInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Vmap
        # NOTE: [Interpreter cdata vs cptr]
        # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
        # so that we can access methods specific to the vmap interpreter
        self._cdata = cdata
        self._cptr = CVmapInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Vmap]
        return kernel(self, *args, **kwargs)

    def batch_size(self):
        return self._cptr.batchSize()

    def randomness(self):
        typ = self._cptr.randomness()
        if typ == RandomnessType.Error:
            return "error"
        elif typ == RandomnessType.Same:
            return "same"
        elif typ == RandomnessType.Different:
            return "different"
        raise RuntimeError(f"Unknown RandomnessType: {typ}")

    def get_state(self):
        return (self.key().name, self.level(), self.randomness())

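# A sketch of what a kernel looked up from op.functorch_table[TransformType.Vmap]
# might do with the VmapInterpreter it receives (the operator `my_op` and the
# rule body are hypothetical): inspect the interpreter's metadata, then use
# lower() to re-dispatch to the next interpreter on the stack.
#
#     def my_op_vmap_rule(interpreter: VmapInterpreter, x, *, dim=0):
#         # Metadata about the current vmap level is available on the interpreter.
#         batch_size = interpreter.batch_size()
#         randomness = interpreter.randomness()  # "error", "same", or "different"
#         with interpreter.lower():
#             # Inside lower(), the call sees the remaining transforms only.
#             return my_op(x, dim=dim)
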
@contextlib.contextmanager
def nested(*contexts):
    with contextlib.ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        yield contexts


class GradInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Grad
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CGradInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Grad]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # GradInterpreter has custom lower because of the no_grad interaction
    # See NOTE [grad and vjp interaction with no_grad]
    # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_grad_mode = self.prev_grad_mode()
        if not prev_grad_mode:
            return nested(torch.no_grad(), super().lower())
        return super().lower()

    def prev_grad_mode(self):
        return self._cptr.prevGradMode()

    def get_state(self):
        return (self.key().name, self.level(), self.prev_grad_mode())


class JvpInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Jvp
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CJvpInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        args, kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return args, kwargs

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Jvp]
        args, kwargs = self.lift(args, kwargs)
        return kernel(self, *args, **kwargs)

    # Jvp has custom lower because of the no_fwd_grad interaction
    # See NOTE [grad and vjp interaction with no_grad] for related info.
    # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
    def lower(self):
        prev_fwd_grad_mode = self.prev_fwd_grad_mode()
        if not prev_fwd_grad_mode:
            return nested(_set_fwd_grad_enabled(False), super().lower())
        return super().lower()

    def prev_fwd_grad_mode(self):
        return self._cptr.prevFwdGradMode()

    def get_state(self):
        return (self.key().name, self.level(), self.prev_fwd_grad_mode())


class FunctionalizeInterpreter(FuncTorchInterpreter):
    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        kernel = op.functorch_table[TransformType.Functionalize]
        return kernel(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        return self._cptr.functionalizeAddBackViews()

    def get_state(self):
        return (self.key().name, self.level())


def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    key = cinterpreter.key()
    if key == TransformType.Grad:
        return GradInterpreter(cinterpreter)
    if key == TransformType.Vmap:
        return VmapInterpreter(cinterpreter)
    if key == TransformType.Jvp:
        return JvpInterpreter(cinterpreter)
    if key == TransformType.Functionalize:
        return FunctionalizeInterpreter(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")


def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
    interpreter = torch._C._functorch.peek_interpreter_stack()
    assert interpreter is not None
    return coerce_cinterpreter(interpreter)


def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]:
    cis = torch._C._functorch.get_interpreter_stack()
    if cis is None:
        return []
    return [coerce_cinterpreter(ci) for ci in cis]


def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool:
    # There are four possible cases covered here:
    # 1. Current stack empty AND stack when generated not empty -> Invalidate
    # 2. Current stack not empty AND stack when generated empty -> Invalidate
    # 3. Current stack and generated stack empty -> Valid FX graph
    # 4. Current stack and generated stack not empty -> Valid if both states match
    peek = torch._C._functorch.peek_interpreter_stack()
    if (peek is None and len(states) != 0) or (peek is not None and len(states) == 0):
        return False

    cis = retrieve_all_functorch_interpreters()
    return len(cis) == len(states) and all(
        ci.check_state(state) for ci, state in zip(cis, states)
    )

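# A sketch of how get_state/check_state and compare_functorch_state fit
# together (the capture/reuse call site is hypothetical): record the state of
# every interpreter on the stack when a graph is captured, then later decide
# whether that graph is still valid for the current stack.
#
#     # At capture time:
#     states = [i.get_state() for i in retrieve_all_functorch_interpreters()]
#
#     # At reuse time:
#     if not compare_functorch_state(states):
#         ...  # recapture: the functorch transform stack has changed
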
def dispatch_functorch(op, args, kwargs):
    interpreter = retrieve_current_functorch_interpreter()
    # In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
    # unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
    # PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
    # transforms, so we manually unwrap the dead tensors here.
    # This logic won't need to exist when we have mode-only functorch.
    args, kwargs = pytree.tree_map_only(
        torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)
    )
    return interpreter.process(op, args, kwargs)
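

# A sketch of the kind of call site that hands control to dispatch_functorch
# (the surrounding dispatch logic is simplified and hypothetical): when the
# functorch interpreter stack is non-empty, route the operator through the
# topmost interpreter instead of the regular DispatchKey-based path.
#
#     def maybe_dispatch_functorch(op, args, kwargs):
#         if torch._C._functorch.peek_interpreter_stack() is not None:
#             return dispatch_functorch(op, args, kwargs)
#         return op(*args, **kwargs)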
