pytorch

Форк
0
/
log_extract.py 
152 строки · 4.4 Кб
1
import argparse
2
import functools
3
import traceback
4
from typing import Callable, List, Optional, Tuple
5

6
from torch.utils.jit.log_extract import (
7
    extract_ir,
8
    load_graph_and_inputs,
9
    run_baseline_no_fusion,
10
    run_nnc,
11
    run_nvfuser,
12
)
13

14
"""
15
Usage:
16
1. Run your script and pipe into a log file
17
  PYTORCH_JIT_LOG_LEVEL=">>graph_fuser" python3 my_test.py &> log.txt
18
2. Run log_extract:
19
  log_extract.py log.txt --nvfuser --nnc-dynamic --nnc-static
20

21
You can also extract the list of extracted IR:
22
  log_extract.py log.txt --output
23

24
Passing in --graphs 0 2 will only run graphs 0 and 2
25
"""
26

27

28
def test_runners(
29
    graphs: List[str],
30
    runners: List[Tuple[str, Callable]],
31
    graph_set: Optional[List[int]],
32
):
33
    for i, ir in enumerate(graphs):
34
        _, inputs = load_graph_and_inputs(ir)
35
        if graph_set and i not in graph_set:
36
            continue
37

38
        print(f"Running Graph {i}")
39
        prev_result = None
40
        prev_runner_name = None
41
        for runner in runners:
42
            runner_name, runner_fn = runner
43
            try:
44
                result = runner_fn(ir, inputs)
45
                if prev_result:
46
                    improvement = (prev_result / result - 1) * 100
47
                    print(
48
                        f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%"
49
                    )
50
                else:
51
                    print(f"{runner_name} : {result:.6f} ms")
52
                prev_result = result
53
                prev_runner_name = runner_name
54
            except RuntimeError:
55
                print(f"  Graph {i} failed for {runner_name} :", traceback.format_exc())
56

57

58
def run():
59
    parser = argparse.ArgumentParser(
60
        description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
61
    )
62
    parser.add_argument("filename", help="Filename of log file")
63
    parser.add_argument(
64
        "--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser"
65
    )
66
    parser.add_argument(
67
        "--no-nvfuser",
68
        dest="nvfuser",
69
        action="store_false",
70
        help="DON'T benchmark nvfuser",
71
    )
72
    parser.set_defaults(nvfuser=False)
73
    parser.add_argument(
74
        "--nnc-static",
75
        dest="nnc_static",
76
        action="store_true",
77
        help="benchmark nnc static",
78
    )
79
    parser.add_argument(
80
        "--no-nnc-static",
81
        dest="nnc_static",
82
        action="store_false",
83
        help="DON'T benchmark nnc static",
84
    )
85
    parser.set_defaults(nnc_static=False)
86

87
    parser.add_argument(
88
        "--nnc-dynamic",
89
        dest="nnc_dynamic",
90
        action="store_true",
91
        help="nnc with dynamic shapes",
92
    )
93
    parser.add_argument(
94
        "--no-nnc-dynamic",
95
        dest="nnc_dynamic",
96
        action="store_false",
97
        help="DONT't benchmark nnc with dynamic shapes",
98
    )
99
    parser.set_defaults(nnc_dynamic=False)
100

101
    parser.add_argument(
102
        "--baseline", dest="baseline", action="store_true", help="benchmark baseline"
103
    )
104
    parser.add_argument(
105
        "--no-baseline",
106
        dest="baseline",
107
        action="store_false",
108
        help="DON'T benchmark baseline",
109
    )
110
    parser.set_defaults(baseline=False)
111

112
    parser.add_argument(
113
        "--output", dest="output", action="store_true", help="Output graph IR"
114
    )
115
    parser.add_argument(
116
        "--no-output", dest="output", action="store_false", help="DON'T output graph IR"
117
    )
118
    parser.set_defaults(output=False)
119

120
    parser.add_argument(
121
        "--graphs", nargs="+", type=int, help="Run only specified graph indices"
122
    )
123

124
    args = parser.parse_args()
125
    graphs = extract_ir(args.filename)
126

127
    graph_set = args.graphs
128
    graph_set = graph_set if graph_set else None
129

130
    options = []
131
    if args.baseline:
132
        options.append(("Baseline no fusion", run_baseline_no_fusion))
133
    if args.nnc_dynamic:
134
        options.append(("NNC Dynamic", functools.partial(run_nnc, dynamic=True)))
135
    if args.nnc_static:
136
        options.append(("NNC Static", functools.partial(run_nnc, dynamic=False)))
137
    if args.nvfuser:
138
        options.append(("NVFuser", run_nvfuser))
139

140
    test_runners(graphs, options, graph_set)
141

142
    if args.output:
143
        quoted = []
144
        for i, ir in enumerate(graphs):
145
            if graph_set and i not in graph_set:
146
                continue
147
            quoted.append('"""' + ir + '"""')
148
        print("[" + ", ".join(quoted) + "]")
149

150

151
if __name__ == "__main__":
152
    run()
153

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.