pytorch

Форк
0
/
executor.py 
60 строк · 1.6 Кб
1
# mypy: allow-untyped-defs
2
from typing import Callable, Optional
3

4
from torch._prims.context import TorchRefsMode
5
from torch.fx import GraphModule
6
from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
7

8

9
def execute(
10
    gm: GraphModule,
11
    *args,
12
    executor: str = "aten",
13
    executor_parameters: Optional[dict] = None,
14
):
15
    """
16
    Prototype ATen executor.
17

18
    Just executes the context's graph.
19
    """
20

21
    if executor == "aten":
22
        return gm.forward(*args)
23

24
    msg = f"Received unexpected value for 'executor': {executor}. Allowed values are: aten."
25
    raise ValueError(msg)
26

27

28
def make_traced(fn: Callable):
29
    """
30
    Returns a function that, when called, will
31
    trace its torch operations to prims and then
32
    execute those prims on the requested trace executor
33
    (possibly lowering them to that trace executor first).
34

35
    Only supports the torch operations defined in _torch_to_reference_map
36
    in context.py and operations with positional args. All args must
37
    be tensors.
38
    In the near future all these restrictions will be lifted.
39

40
    Example usage:
41

42
    def foo(a, b):
43
      return torch.add(a, b)
44

45
    traced_foo = make_traced(foo)
46

47
    a = torch.randn((1, 2, 3, 4, 5), device='cuda')
48
    b = torch.randn((1, 2, 3, 4, 5), device='cuda')
49
    result = traced_foo(a, b, executor='aten')
50
    """
51

52
    def _traced(*args, executor="aten", **kwargs):
53
        # TODO: caching
54
        wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)
55

56
        with TorchRefsMode():
57
            gm = make_fx(wrapped)(all_args)
58
        return execute(gm, all_args, executor=executor)
59

60
    return _traced
61

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.