pytorch / graph_utils.py (145 lines · 4.8 KB)
import logging
import os
import tempfile
from enum import Enum
from typing import Callable, cast, Dict, Iterable, List, Set

import torch.fx as fx
from torch.fx.passes.shape_prop import TensorMetadata
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_unflatten


logger: logging.Logger = logging.getLogger("graph_utils")


class OP(str, Enum):
    CALL_FUNCTION = "call_function"
    CALL_MODULE = "call_module"
    CALL_METHOD = "call_method"
    GET_ATTR = "get_attr"
    OUTPUT = "output"
    PLACEHOLDER = "placeholder"


class CommType(str, Enum):
    ALLREDUCE = "allreduce_"
    ALLGATHER = "allgather_"
    BROADCAST = "broadcast_"
    REDUCESCATTER = "reduce_scatter_"
    SCATTER = "scatter_"


def get_node_tensor_metadata(node: fx.Node, is_required: bool = True) -> TensorMetadata:
    """Return the ``tensor_meta`` recorded in ``node.meta``, raising if it is required but missing."""
    metadata = node.meta.get("tensor_meta", None)
    if is_required and metadata is None:
        raise RuntimeError(
            f"Callsite expects that ``tensor_meta`` exists in ``{node.name}``, "
            f"but got None instead. Node: {node.op} {node.name} {node.target}"
        )
    return metadata

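
# Illustrative usage sketch (not part of the original module): ``tensor_meta`` is
# normally populated by running shape propagation over a traced module, after which
# ``get_node_tensor_metadata`` can read it back. The toy model and input below are
# hypothetical placeholders.
def _example_get_node_tensor_metadata() -> None:
    import torch
    from torch.fx.passes.shape_prop import ShapeProp

    gm = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
    ShapeProp(gm).propagate(torch.randn(2, 4))  # fills node.meta["tensor_meta"]
    for node in gm.graph.nodes:
        if node.op == OP.CALL_MODULE:
            print(get_node_tensor_metadata(node).shape)
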

def get_output(graph: fx.Graph) -> fx.Node:
    """Take a graph and return its output node.

    We traverse in reverse to expedite the search, since the output node is normally the last node.
    """
    for node in reversed(graph.nodes):
        if node.op == OP.OUTPUT:
            return node
    raise RuntimeError(f"Cannot find the output node in {graph}")


def find_node(
    graph: fx.Graph, predicate: Callable, reverse_order: bool = False
) -> List[fx.Node]:
    """Take a predicate and return all the nodes in the `graph` where the predicate holds."""
    nodes = cast(Iterable[fx.Node], graph.nodes)
    if reverse_order:
        nodes = cast(Iterable[fx.Node], iter(reversed(nodes)))  # type: ignore[call-overload]
    return [node for node in nodes if predicate(node)]

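
# Illustrative usage sketch (not part of the original module): use ``find_node`` with a
# simple predicate to collect all ``call_function`` nodes of an already-traced
# fx.GraphModule. The ``gm`` argument is assumed to be supplied by the caller.
def _example_find_node(gm: fx.GraphModule) -> List[fx.Node]:
    # Passing reverse_order=True would return the same nodes in reverse graph order.
    return find_node(gm.graph, lambda node: node.op == OP.CALL_FUNCTION)
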

def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
    """Check that every node in ``subgraph`` satisfies one of the following rules.

    1. The user of the node is in ``subgraph``.
    2. The user of the node is the output.
    3. There are no users -- the node is a side-effect node.
    """
    all_nodes: Set[fx.Node] = set(subgraph)
    output = get_output(graph)
    for node in subgraph:
        for user in node.users:
            if not isinstance(user, fx.Node):
                continue
            if user not in all_nodes and user != output:
                return False
    return True


def clone_subgraph(
    graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node
) -> List[fx.Node]:
    """Clone the given subgraph and insert it before ``target``.

    This API currently does not support inserting after ``target``.
    """
    all_nodes = set(subgraph)
    mapping: Dict[fx.Node, fx.Node] = dict()
    cloned_subgraph = []
    with graph.inserting_before(target):
        for node in subgraph:
            cloned_node = graph.call_function(
                node.target, node.args, node.kwargs, node.type
            )
            # TODO: there are many flatten/unflatten in IterGraph that
            # can be simplified with tree_map. Will simplify this in
            # a follow-up PR.
            original_input = pytree.arg_tree_leaves(*node.args, **node.kwargs)
            cloned_input, spec = tree_flatten((cloned_node.args, cloned_node.kwargs))
            mapped_cloned_input = []
            for original_input_node, cloned_input_node in zip(
                original_input, cloned_input
            ):
                if (
                    isinstance(original_input_node, fx.Node)
                    and original_input_node in all_nodes
                ):
                    assert original_input_node in mapping
                    mapped_cloned_input.append(mapping[original_input_node])
                else:
                    mapped_cloned_input.append(cloned_input_node)
            cloned_node.args, cloned_node.kwargs = tree_unflatten(
                mapped_cloned_input, spec
            )
            mapping[node] = cloned_node
            cloned_subgraph.append(cloned_node)

    return cloned_subgraph

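
# Illustrative usage sketch (not part of the original module): clone a chain of
# ``call_function`` nodes in front of the output node, but only if the chain forms a
# leaf subgraph. How the candidate nodes are picked here is a hypothetical placeholder;
# a real caller (e.g. IterGraph) would also re-wire users to the cloned nodes afterwards.
def _example_clone_subgraph(gm: fx.GraphModule) -> List[fx.Node]:
    candidates = find_node(gm.graph, lambda node: node.op == OP.CALL_FUNCTION)
    if not candidates or not is_leaf_subgraph(gm.graph, candidates):
        return []
    cloned = clone_subgraph(gm.graph, candidates, target=get_output(gm.graph))
    # Keep dead-code elimination off: the clones have no users until they are re-wired.
    rebuild_graph(gm, remove_dead_code=False)
    return cloned
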

def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
    """Run the required steps to ensure a production-ready graph.

    Note: per the fx docs, dead-code elimination is not very precise,
    hence the flag that makes this step optional.
    """
    gm.graph.lint()
    if remove_dead_code:
        gm.graph.eliminate_dead_code()
    gm.recompile()

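
# Illustrative usage sketch (not part of the original module): a typical edit-then-rebuild
# flow. The clean-up rule below (erasing call_function nodes that have no users) is a
# hypothetical placeholder for whatever graph surgery the caller performs.
def _example_rebuild_graph(gm: fx.GraphModule) -> None:
    # Walk in reverse graph order so consumers are erased before their producers,
    # letting removals cascade; erase_node requires the node to have no users.
    for node in find_node(
        gm.graph, lambda n: n.op == OP.CALL_FUNCTION, reverse_order=True
    ):
        if len(node.users) == 0:
            gm.graph.erase_node(node)
    rebuild_graph(gm)
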

def dump_graphs_to_files(graphs: Dict[str, fx.GraphModule], folder: str = "") -> str:
    """Write ``str(gm)`` of each GraphModule to ``<folder>/<prefix>.graph`` and return the folder."""
    if not folder:
        folder = tempfile.mkdtemp()

    for prefix, gm in graphs.items():
        with open(os.path.join(folder, f"{prefix}.graph"), "w") as fp:
            fp.write(str(gm))

    logger.warning("Dump graphs to %s", folder)

    return folder
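
# Illustrative usage sketch (not part of the original module): write the readable form of
# one or more GraphModules to files for offline inspection. The toy module and the
# "traced" prefix below are hypothetical placeholders.
def _example_dump_graphs_to_files() -> str:
    import torch

    gm = fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
    return dump_graphs_to_files({"traced": gm})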
