4
from typing import Callable, List, Optional, Tuple
6
from torch.utils.jit.log_extract import (
9
run_baseline_no_fusion,
16
1. Run your script and pipe into a log file
17
PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
19
log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static
21
You can also extract the list of extracted IR:
22
log_extract.py log.txt --output
24
Passing in --graphs 0 2 will only run graphs 0 and 2
30
runners: List[Tuple[str, Callable]],
31
graph_set: Optional[List[int]],
33
for i, ir in enumerate(graphs):
34
_, inputs = load_graph_and_inputs(ir)
35
if graph_set and i not in graph_set:
38
print(f"Running Graph {i}")
40
prev_runner_name = None
41
for runner in runners:
42
runner_name, runner_fn = runner
44
result = runner_fn(ir, inputs)
46
improvement = (prev_result / result - 1) * 100
48
f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
51
print(f"{runner_name} : {result:.6f} ms")
53
prev_runner_name = runner_name
55
print(f" Graph {i} failed for {runner_name} :", traceback.format_exc())
59
parser = argparse.ArgumentParser(
60
description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
62
parser.add_argument("filename", help="Filename of log file")
64
"--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser"
70
help="DON'T benchmark nvfuser",
72
parser.set_defaults(nvfuser=False)
77
help="benchmark nnc static",
83
help="DON'T benchmark nnc static",
85
parser.set_defaults(nnc_static=False)
91
help="nnc with dynamic shapes",
97
help="DONT't benchmark nnc with dynamic shapes",
99
parser.set_defaults(nnc_dynamic=False)
102
"--baseline", dest="baseline", action="store_true", help="benchmark baseline"
107
action="store_false",
108
help="DON'T benchmark baseline",
110
parser.set_defaults(baseline=False)
113
"--output", dest="output", action="store_true", help="Output graph IR"
116
"--no-output", dest="output", action="store_false", help="DON'T output graph IR"
118
parser.set_defaults(output=False)
121
"--graphs", nargs="+", type=int, help="Run only specified graph indices"
124
args = parser.parse_args()
125
graphs = extract_ir(args.filename)
127
graph_set = args.graphs
128
graph_set = graph_set if graph_set else None
132
options.append(("Baseline no fusion", run_baseline_no_fusion))
134
options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
136
options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
138
options.append(("NVFuser", run_nvfuser))
140
test_runners(graphs, options, graph_set)
144
for i, ir in enumerate(graphs):
145
if graph_set and i not in graph_set:
147
quoted.append('"""' + ir + '"""')
148
print("[" + ", ".join(quoted) + "]")
151
if __name__ == "__main__":