4
from typing import Callable, List, Optional, Tuple
6
from torch.utils.jit.log_extract import (
9
run_baseline_no_fusion,
17
1. Run your script and pipe into a log file
18
PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
20
log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static
22
You can also extract the list of extracted IR:
23
log_extract.py log.txt --output
25
Passing in --graphs 0 2 will only run graphs 0 and 2
31
runners: List[Tuple[str, Callable]],
32
graph_set: Optional[List[int]],
34
for i, ir in enumerate(graphs):
35
_, inputs = load_graph_and_inputs(ir)
36
if graph_set and i not in graph_set:
39
print(f"Running Graph {i}")
41
prev_runner_name = None
42
for runner in runners:
43
runner_name, runner_fn = runner
45
result = runner_fn(ir, inputs)
47
improvement = (prev_result / result - 1) * 100
49
f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
52
print(f"{runner_name} : {result:.6f} ms")
54
prev_runner_name = runner_name
56
print(f" Graph {i} failed for {runner_name} :", traceback.format_exc())
60
parser = argparse.ArgumentParser(
61
description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
63
parser.add_argument("filename", help="Filename of log file")
65
"--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser"
71
help="DON'T benchmark nvfuser",
73
parser.set_defaults(nvfuser=False)
78
help="benchmark nnc static",
84
help="DON'T benchmark nnc static",
86
parser.set_defaults(nnc_static=False)
92
help="nnc with dynamic shapes",
98
help="DONT't benchmark nnc with dynamic shapes",
100
parser.set_defaults(nnc_dynamic=False)
103
"--baseline", dest="baseline", action="store_true", help="benchmark baseline"
108
action="store_false",
109
help="DON'T benchmark baseline",
111
parser.set_defaults(baseline=False)
114
"--output", dest="output", action="store_true", help="Output graph IR"
117
"--no-output", dest="output", action="store_false", help="DON'T output graph IR"
119
parser.set_defaults(output=False)
122
"--graphs", nargs="+", type=int, help="Run only specified graph indices"
125
args = parser.parse_args()
126
graphs = extract_ir(args.filename)
128
graph_set = args.graphs
129
graph_set = graph_set if graph_set else None
133
options.append(("Baseline no fusion", run_baseline_no_fusion))
135
options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
137
options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
139
options.append(("NVFuser", run_nvfuser))
141
test_runners(graphs, options, graph_set)
145
for i, ir in enumerate(graphs):
146
if graph_set and i not in graph_set:
148
quoted.append('"""' + ir + '"""')
149
print("[" + ", ".join(quoted) + "]")
152
if __name__ == "__main__":