pytorch
1import contextlib
2from typing import Callable, List, TYPE_CHECKING
3
4if TYPE_CHECKING:
5import torch
6
# Module-level registry of the currently active intermediate hooks.
# Each entry is a callable invoked as ``hook(name, tensor)`` whose return
# value is ignored. Hooks are executed in the order they're registered.
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
9
10
@contextlib.contextmanager
def intermediate_hook(fn):
    """Register ``fn`` as an intermediate hook for the duration of a ``with`` block.

    ``fn`` is appended to ``INTERMEDIATE_HOOKS`` on entry, and the last entry
    of ``INTERMEDIATE_HOOKS`` is popped on exit, including when the body of
    the ``with`` block raises.

    Args:
        fn: The hook to register; the registry's annotation describes it as
            a callable taking ``(name, tensor)``.

    Yields:
        None. The context manager exposes no value to the ``with`` body.
    """
    INTERMEDIATE_HOOKS.append(fn)
    try:
        yield
    finally:
        # Removes whatever is last in the registry, which is ``fn`` as long
        # as registrations are unwound in LIFO order (nested ``with`` blocks).
        INTERMEDIATE_HOOKS.pop()
18
19
def run_intermediate_hooks(name, val):
    """Invoke every registered intermediate hook with ``(name, val)``.

    While the hooks run, the global ``INTERMEDIATE_HOOKS`` is rebound to an
    empty list, so a call to ``run_intermediate_hooks`` made from inside a
    hook finds nothing to invoke. The original list object is rebound to the
    global afterwards, even if a hook raises.

    Args:
        name: Passed through unchanged as the first argument of each hook.
        val: Passed through unchanged as the second argument of each hook.

    Raises:
        Whatever a hook raises; remaining hooks are then skipped.
    """
    global INTERMEDIATE_HOOKS

    # Swap the active registry out for an empty one before calling anything.
    active_hooks, INTERMEDIATE_HOOKS = INTERMEDIATE_HOOKS, []
    try:
        # Registration order is preserved: iterate the saved list front to back.
        for callback in active_hooks:
            callback(name, val)
    finally:
        # Always reinstate the saved registry, discarding the temporary list.
        INTERMEDIATE_HOOKS = active_hooks
29