# pytorch (294 lines · 10.1 KB)
# mypy: allow-untyped-defs
import contextlib
from abc import ABC, abstractmethod
from typing import Any, List, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    CFunctionalizeInterpreterPtr,
    CGradInterpreterPtr,
    CInterpreter,
    CJvpInterpreterPtr,
    CVmapInterpreterPtr,
    pop_dynamic_layer_stack,
    push_dynamic_layer_stack,
    RandomnessType,
    TransformType,
)
from torch.autograd.forward_ad import _set_fwd_grad_enabled
21
"""
This file contains the functorch integration with PyDispatcher.

PyDispatcher does not understand functorch's DynamicLayerStack dispatching
logic because it is entirely implemented in C++ in the fallbacks for two
dispatch keys, FuncTorchDynamicLayer{Front, Back}Mode (PyDispatcher is unable
to directly reuse C++ boxed fallbacks).

Instead of trying to hammer PyDispatcher into understanding those fallbacks,
we re-implement the logic of peeking the top of the stack for an interpreter,
selecting the interpreter to dispatch on, etc, in Python. This leads to a
simpler design.

The main difference between C++ functorch and PyDispatcher's functorch logic
is that:
- C++ functorch needs to manually tweak dispatch keys to ping-pong between
  DynamicLayerFrontMode and DynamicLayerBackMode.
- PyDispatcher's functorch logic pops an Interpreter from the top of the stack
  and asks it to execute the rule associated with the Interpreter.

In C++ we do the ping-pong because e.g. vmap rules are associated with the
batched DispatchKey, but in PyDispatcher we are able to avoid this by asking
the user to register a batching rule directly to a transform that an
interpreter then invokes.
"""
47
48
class FuncTorchInterpreter(ABC):
    """Python version of the C++ functorch Interpreter.

    Recall that the DynamicLayerStack is a stack of interpreters; this class
    is a wrapper around the actual C++ Interpreter object.

    Keep the methods in sync with aten/src/ATen/functorch/Interpreter.h
    """

    def __init__(self, cptr: Any):
        self._cptr = cptr

    @abstractmethod
    def process(self, op, args, kwargs):
        """Process an operation, e.g. for vmap this invokes a batching rule.

        Conceptually analogous to Interpreter::process in C++.
        """

    def lower(self):
        """Lower an operation from this Interpreter to the next one on the stack.

        Concretely, this temporarily pops the current Interpreter.
        Conceptually analogous to Interpreter::sendToNextInterpreter in C++.
        """
        return temporarily_pop_interpreter_stack()

    def level(self):
        """Return this interpreter's level, delegated to the C++ object."""
        return self._cptr.level()

    def key(self):
        """Return this interpreter's TransformType key, delegated to C++."""
        return self._cptr.key()

    def get_state(self):
        # Subclasses override this to return a hashable state tuple.
        raise NotImplementedError

    def check_state(self, state):
        """Return True if ``state`` equals this interpreter's current state."""
        return state == self.get_state()
81
82
@contextlib.contextmanager
def temporarily_pop_interpreter_stack():
    """Context manager that pops the top Interpreter off the DynamicLayerStack
    for the duration of the ``with`` block and pushes it back on exit.

    Fix: the pop happens *before* entering the ``try`` block. In the original
    code the pop was inside the ``try``, so if ``pop_dynamic_layer_stack``
    itself raised, the ``finally`` clause would reference the unbound name
    ``saved`` and raise a NameError that masks the original error (and would
    attempt to push back something that was never popped).
    """
    saved = pop_dynamic_layer_stack()
    try:
        yield
    finally:
        push_dynamic_layer_stack(saved)
90
91
@contextlib.contextmanager
def temporarily_clear_interpreter_stack():
    """Pop every Interpreter off the DynamicLayerStack for the duration of
    the ``with`` block, yielding the popped interpreters (topmost first).

    All popped interpreters are pushed back, in the original order, on exit.
    """
    popped = []
    try:
        while torch._C._functorch.peek_interpreter_stack() is not None:
            popped.append(pop_dynamic_layer_stack())
        # Yield a copy so callers cannot disturb our restore bookkeeping.
        yield list(popped)
    finally:
        while popped:
            push_dynamic_layer_stack(popped.pop())
102
103
@contextlib.contextmanager
def temporarily_restore_interpreter_stack(stack):
    """Push the interpreters in ``stack`` (ordered topmost-first) back onto
    the DynamicLayerStack for the duration of the ``with`` block, and pop
    them again on exit."""
    num_pushed = 0
    try:
        for interp in reversed(stack):
            push_dynamic_layer_stack(interp)
            num_pushed += 1
        yield
    finally:
        # TODO: would be nice to assert that the layers are the same, but
        # Python object identity is not preserved
        for _ in range(num_pushed):
            pop_dynamic_layer_stack()
117
118
class VmapInterpreter(FuncTorchInterpreter):
    """Python wrapper for a C++ vmap interpreter."""

    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Vmap
        # NOTE: [Interpreter cdata vs cptr]
        # cdata is a generic CInterpreter. We wrap it in a CVmapInterpreterPtr
        # so that we can access methods specific to the vmap interpreter
        self._cdata = cdata
        self._cptr = CVmapInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        """Invoke the vmap (batching) rule registered for ``op``."""
        vmap_rule = op.functorch_table[TransformType.Vmap]
        return vmap_rule(self, *args, **kwargs)

    def batch_size(self):
        """Return this vmap level's batch size (from the C++ interpreter)."""
        return self._cptr.batchSize()

    def randomness(self):
        """Return the vmap randomness setting as one of the strings
        "error", "same", or "different"."""
        typ = self._cptr.randomness()
        for candidate, name in (
            (RandomnessType.Error, "error"),
            (RandomnessType.Same, "same"),
            (RandomnessType.Different, "different"),
        ):
            if typ == candidate:
                return name
        raise RuntimeError(f"Unknown RandomnessType: {typ}")

    def get_state(self):
        """Return a tuple identifying this interpreter's configuration."""
        return (self.key().name, self.level(), self.randomness())
147
148
@contextlib.contextmanager
def nested(*contexts):
    """Enter all of ``contexts`` left-to-right, yielding the tuple of
    context managers; they are exited in reverse order on the way out."""
    with contextlib.ExitStack() as exit_stack:
        for cm in contexts:
            exit_stack.enter_context(cm)
        yield contexts
155
156
class GradInterpreter(FuncTorchInterpreter):
    """Python wrapper for a C++ grad interpreter."""

    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Grad
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CGradInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        """Map the C++ interpreter's ``lift`` over every Tensor in
        args/kwargs, leaving all other leaves untouched."""
        lifted_args, lifted_kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return lifted_args, lifted_kwargs

    def process(self, op, args, kwargs):
        """Invoke the Grad rule registered for ``op`` on lifted args."""
        grad_rule = op.functorch_table[TransformType.Grad]
        args, kwargs = self.lift(args, kwargs)
        return grad_rule(self, *args, **kwargs)

    # GradInterpreter has custom lower because of the no_grad interaction
    # See NOTE [grad and vjp interaction with no_grad]
    # This logic is mirrored from C++ GradInterpreterPtr::sendToNextInterpreter
    def lower(self):
        if self.prev_grad_mode():
            return super().lower()
        return nested(torch.no_grad(), super().lower())

    def prev_grad_mode(self):
        """Return the grad-mode flag captured when this level was entered."""
        return self._cptr.prevGradMode()

    def get_state(self):
        """Return a tuple identifying this interpreter's configuration."""
        return (self.key().name, self.level(), self.prev_grad_mode())
189
190
class JvpInterpreter(FuncTorchInterpreter):
    """Python wrapper for a C++ jvp (forward-mode AD) interpreter."""

    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Jvp
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CJvpInterpreterPtr(cdata)

    def lift(self, args, kwargs):
        """Map the C++ interpreter's ``lift`` over every Tensor in
        args/kwargs, leaving all other leaves untouched."""
        lifted_args, lifted_kwargs = pytree.tree_map_only(
            torch.Tensor, self._cptr.lift, [args, kwargs]
        )
        return lifted_args, lifted_kwargs

    def process(self, op, args, kwargs):
        """Invoke the Jvp rule registered for ``op`` on lifted args."""
        jvp_rule = op.functorch_table[TransformType.Jvp]
        args, kwargs = self.lift(args, kwargs)
        return jvp_rule(self, *args, **kwargs)

    # Jvp has custom lower because of the no_fwd_grad interaction
    # See NOTE [grad and vjp interaction with no_grad] for related info.
    # This logic is mirrored from C++ JvpInterpreterPtr::sendToNextInterpreter
    def lower(self):
        if self.prev_fwd_grad_mode():
            return super().lower()
        return nested(_set_fwd_grad_enabled(False), super().lower())

    def prev_fwd_grad_mode(self):
        """Return the fwd-grad-mode flag captured when this level was entered."""
        return self._cptr.prevFwdGradMode()

    def get_state(self):
        """Return a tuple identifying this interpreter's configuration."""
        return (self.key().name, self.level(), self.prev_fwd_grad_mode())
223
224
class FunctionalizeInterpreter(FuncTorchInterpreter):
    """Python wrapper for a C++ functionalize interpreter."""

    def __init__(self, cdata: CInterpreter):
        assert cdata.key() == TransformType.Functionalize
        # See NOTE: [Interpreter cdata vs cptr]
        self._cdata = cdata
        self._cptr = CFunctionalizeInterpreterPtr(cdata)

    def process(self, op, args, kwargs):
        """Invoke the Functionalize rule registered for ``op``."""
        functionalize_rule = op.functorch_table[TransformType.Functionalize]
        return functionalize_rule(self, *args, **kwargs)

    def functionalize_add_back_views(self):
        """Return the add-back-views flag from the C++ interpreter."""
        return self._cptr.functionalizeAddBackViews()

    def get_state(self):
        """Return a tuple identifying this interpreter's configuration."""
        return (self.key().name, self.level())
240
241
def coerce_cinterpreter(cinterpreter: CInterpreter) -> FuncTorchInterpreter:
    """Wrap a generic CInterpreter in the Python wrapper matching its key.

    Raises RuntimeError for transform types PyDispatcher does not support yet.
    """
    key = cinterpreter.key()
    for transform, wrapper in (
        (TransformType.Grad, GradInterpreter),
        (TransformType.Vmap, VmapInterpreter),
        (TransformType.Jvp, JvpInterpreter),
        (TransformType.Functionalize, FunctionalizeInterpreter),
    ):
        if key == transform:
            return wrapper(cinterpreter)
    raise RuntimeError(f"NYI: PyDispatcher has not implemented support for {key}")
253
254
def retrieve_current_functorch_interpreter() -> FuncTorchInterpreter:
    """Return a Python wrapper for the interpreter at the top of the
    DynamicLayerStack. The stack must be non-empty."""
    top = torch._C._functorch.peek_interpreter_stack()
    assert top is not None
    return coerce_cinterpreter(top)
259
260
def retrieve_all_functorch_interpreters() -> List[FuncTorchInterpreter]:
    """Return Python wrappers for every interpreter on the DynamicLayerStack
    (empty list when there are none)."""
    stack = torch._C._functorch.get_interpreter_stack()
    if stack is None:
        return []
    return [coerce_cinterpreter(interp) for interp in stack]
266
267
def compare_functorch_state(states: List[Tuple[Any, ...]]) -> bool:
    """Check whether the current interpreter stack matches ``states``.

    There are four possible cases covered here:
    1. Current stack empty AND stack when generated not empty -> Invalidate
    2. Current stack not empty AND stack when generated empty -> Invalidate
    3. Current stack and generated stack empty -> Valid FX graph
    4. Current stack and generated stack not empty -> Valid if both states match
    """
    peek = torch._C._functorch.peek_interpreter_stack()
    stack_is_empty = peek is None
    states_is_empty = len(states) == 0
    # Emptiness mismatch (cases 1 and 2) invalidates immediately.
    if stack_is_empty != states_is_empty:
        return False

    interpreters = retrieve_all_functorch_interpreters()
    if len(interpreters) != len(states):
        return False
    return all(
        interp.check_state(state) for interp, state in zip(interpreters, states)
    )
282
283
def dispatch_functorch(op, args, kwargs):
    """Dispatch ``op`` to the interpreter at the top of the functorch stack.

    In traditional PyTorch operators, DispatchKey::FuncTorchTensorWrapper's
    unwrap_dead_tensors fallback handles unwrapping dead tensor wrappers.
    PyDispatcher sidesteps the PyTorch dispatcher when dealing with functorch
    transforms, so we manually unwrap the dead tensors here.
    This logic won't need to exist when we have mode-only functorch.
    """
    interpreter = retrieve_current_functorch_interpreter()
    unwrapped_args, unwrapped_kwargs = pytree.tree_map_only(
        torch.Tensor, torch._C._functorch.unwrap_if_dead, (args, kwargs)
    )
    return interpreter.process(op, unwrapped_args, unwrapped_kwargs)
295