1
# mypy: allow-untyped-defs
2
from typing import Callable, Optional
4
from torch._prims.context import TorchRefsMode
5
from torch.fx import GraphModule
6
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
12
executor: str = "aten",
13
executor_parameters: Optional[dict] = None,
16
Prototype ATen executor.
18
Just executes the context's graph.
21
if executor == "aten":
22
return gm.forward(*args)
24
msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten."
28
def make_traced(fn: Callable):
30
Returns a function that, when called, will
31
trace its torch operations to prims and then
32
execute those prims on the requested trace executor
33
(possibly lowering them to that trace executor first).
35
Only supports the torch operations defined in _torch_to_reference_map
36
in context.py and operations with positional args. All args must
38
In the near future all these restrictions will be lifted.
43
return torch.add(a, b)
45
traced_foo = make_traced(foo)
47
a = torch.randn((1, 2, 3, 4, 5), device='cuda')
48
b = torch.randn((1, 2, 3, 4, 5), device='cuda')
49
result = traced_foo(a, b, executor='aten')
52
def _traced(*args, executor="aten", **kwargs):
54
wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)
57
gm = make_fx(wrapped)(all_args)
58
return execute(gm, all_args, executor=executor)