# pytorch
# 145 lines · 4.8 KB
1import logging2import os3import tempfile4from enum import Enum5from typing import Callable, cast, Dict, Iterable, List, Set6
7import torch.fx as fx8from torch.fx.passes.shape_prop import TensorMetadata9from torch.utils import _pytree as pytree10from torch.utils._pytree import tree_flatten, tree_unflatten11
12
# Module-level logger shared by the helpers in this file; it is registered
# under the fixed name "graph_utils" rather than the module's ``__name__``.
logger: logging.Logger = logging.getLogger("graph_utils")
15
class OP(str, Enum):
    """Opcode strings that this file compares against ``fx.Node.op``.

    Every member also subclasses ``str``, so it compares equal to its plain
    string value (for example ``OP.OUTPUT == "output"``).
    """

    # Nodes that invoke something.
    CALL_FUNCTION = "call_function"
    CALL_MODULE = "call_module"
    CALL_METHOD = "call_method"

    # Nodes that read an attribute.
    GET_ATTR = "get_attr"

    # Graph boundary nodes.
    OUTPUT = "output"
    PLACEHOLDER = "placeholder"
24
class CommType(str, Enum):
    """Identifiers for the supported communication operation types.

    Every member also subclasses ``str`` and compares equal to its plain
    string value. Each value ends with an underscore -- presumably because it
    is matched as a name prefix; confirm against the callers, which are not
    visible in this file.
    """

    ALLREDUCE = "allreduce_"
    ALLGATHER = "allgather_"
    BROADCAST = "broadcast_"
    REDUCESCATTER = "reduce_scatter_"
    SCATTER = "scatter_"
32
def get_node_tensor_metadata(node: fx.Node, is_required: bool = True) -> TensorMetadata:
    """Return the ``"tensor_meta"`` entry stored in ``node.meta``.

    Args:
        node: Node whose ``meta`` dictionary is consulted.
        is_required: When true, a missing (or ``None``) entry is an error.

    Returns:
        The value stored under ``"tensor_meta"``. Note that when
        ``is_required`` is false and no entry exists, ``None`` is returned
        even though the annotation says ``TensorMetadata``.

    Raises:
        RuntimeError: If ``is_required`` is true and the entry is absent.
    """
    tensor_meta = node.meta.get("tensor_meta")
    # Anything usable -- or an absence the caller explicitly tolerates -- is
    # handed back unchanged.
    if tensor_meta is not None or not is_required:
        return tensor_meta
    raise RuntimeError(
        f"Callsite expects that ``tensor_meta`` exists in ``{node.name}``, "
        f"but got None instead. Node: {node.op} {node.name} {node.target}"
    )
42
def get_output(graph: fx.Graph) -> fx.Node:
    """Return the output node of ``graph``.

    The nodes are scanned back to front, on the expectation that the output
    node is the last one, so the search normally stops at the first node
    inspected.

    Raises:
        RuntimeError: If no node in ``graph`` has the output opcode.
    """
    output_node = next(
        (candidate for candidate in reversed(graph.nodes) if candidate.op == OP.OUTPUT),
        None,
    )
    if output_node is None:
        raise RuntimeError(f"Cannot find the output node in {graph}")
    return output_node
53
def find_node(
    graph: fx.Graph, predicate: Callable, reverse_order: bool = False
) -> List[fx.Node]:
    """Collect every node of ``graph`` for which ``predicate`` is truthy.

    Args:
        graph: Graph whose nodes are inspected.
        predicate: Called with each node; the node is kept when the result is
            truthy.
        reverse_order: When true, nodes are visited (and returned) from last
            to first instead of first to last.

    Returns:
        The matching nodes, in visiting order.
    """
    if reverse_order:
        candidates = cast(Iterable[fx.Node], reversed(graph.nodes))  # type: ignore[call-overload]
    else:
        candidates = cast(Iterable[fx.Node], graph.nodes)
    return [candidate for candidate in candidates if predicate(candidate)]
63
def is_leaf_subgraph(graph: fx.Graph, subgraph: List[fx.Node]) -> bool:
    """Tell whether nothing outside ``subgraph`` (except the output) uses it.

    The answer is ``True`` when every user of every node in ``subgraph`` is
    either itself a member of ``subgraph`` or the output node of ``graph``.
    A node without users (a side-effect node) trivially satisfies this.
    Users that are not ``fx.Node`` instances are ignored.

    Raises:
        RuntimeError: Propagated from ``get_output`` when ``graph`` has no
            output node.
    """
    members: Set[fx.Node] = set(subgraph)
    sink = get_output(graph)
    return all(
        consumer in members or consumer == sink
        for member in subgraph
        for consumer in member.users
        if isinstance(consumer, fx.Node)
    )
81
82def clone_subgraph(83graph: fx.Graph, subgraph: List[fx.Node], target: fx.Node84) -> List[fx.Node]:85"""Clone the given subgraph and insert it before ``target``.86
87This API currently does not support inserting after ``target``.
88"""
89all_nodes = set(subgraph)90mapping: Dict[fx.Node, fx.Node] = dict()91cloned_subgraph = []92with graph.inserting_before(target):93for node in subgraph:94cloned_node = graph.call_function(95node.target, node.args, node.kwargs, node.type96)97# TODO: there are many flatten/unflatten in IterGraph that98# can be simplified with tree_map. Will simplify this in99# a follow-up PR.100original_input = pytree.arg_tree_leaves(*node.args, **node.kwargs)101cloned_input, spec = tree_flatten((cloned_node.args, cloned_node.kwargs))102mapped_cloned_input = []103for original_input_node, cloned_input_node in zip(104original_input, cloned_input105):106if (107isinstance(original_input_node, fx.Node)108and original_input_node in all_nodes109):110assert original_input_node in mapping111mapped_cloned_input.append(mapping[original_input_node])112else:113mapped_cloned_input.append(cloned_input_node)114cloned_node.args, cloned_node.kwargs = tree_unflatten(115mapped_cloned_input, spec116)117mapping[node] = cloned_node118cloned_subgraph.append(cloned_node)119
120return cloned_subgraph121
122
def rebuild_graph(gm: fx.GraphModule, remove_dead_code: bool = True) -> None:
    """Lint the graph of ``gm``, optionally prune dead code, and recompile.

    Per the fx docs, dead-code elimination is not very precise, which is why
    that step can be switched off.

    Args:
        gm: Graph module that is checked and recompiled in place.
        remove_dead_code: Whether to run ``eliminate_dead_code`` between the
            lint and the recompile.
    """
    graph = gm.graph
    graph.lint()
    if remove_dead_code:
        graph.eliminate_dead_code()
    gm.recompile()
134
def dump_graphs_to_files(graphs: Dict[str, fx.GraphModule], folder: str = "") -> str:
    """Write the string form of each graph module to ``<folder>/<key>.graph``.

    Args:
        graphs: Maps a file-name prefix to the graph module to dump; each
            module is serialized with ``str()``.
        folder: Destination directory. It is created if it does not exist
            yet. When empty, a fresh temporary directory is created instead.

    Returns:
        The directory the files were written to.
    """
    if folder:
        # A caller-supplied directory may not exist yet; without this the
        # first ``open`` below fails with FileNotFoundError.
        os.makedirs(folder, exist_ok=True)
    else:
        folder = tempfile.mkdtemp()

    for prefix, gm in graphs.items():
        # Explicit encoding so the dump does not depend on the platform's
        # default locale.
        with open(
            os.path.join(folder, f"{prefix}.graph"), "w", encoding="utf-8"
        ) as fp:
            fp.write(str(gm))

    logger.warning("Dump graphs to %s", folder)

    return folder